From 7a9bf5f9f0717cb316b40a83484fb51c2cf0ea89 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Wed, 10 May 2023 14:02:02 -0700 Subject: [PATCH 01/37] Blacken --- tests/test_connection.py | 526 ++++++++++++++++------------ trio_websocket/_impl.py | 699 ++++++++++++++++++++++--------------- trio_websocket/_version.py | 2 +- 3 files changed, 727 insertions(+), 500 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 79bb9b4..1f29279 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,4 @@ -''' +""" Unit tests for trio_websocket. Many of these tests involve networking, i.e. real TCP sockets. To maximize @@ -28,7 +28,7 @@ call ``ws.get_message()`` without actually sending it a message. This will cause the server to block until the client has sent the closing handshake. In other circumstances -''' +""" from functools import partial, wraps import ssl from unittest.mock import patch @@ -61,13 +61,13 @@ WebSocketServer, WebSocketRequest, wrap_client_stream, - wrap_server_stream + wrap_server_stream, ) -WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.'))) +WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split("."))) -HOST = '127.0.0.1' -RESOURCE = '/resource' +HOST = "127.0.0.1" +RESOURCE = "/resource" DEFAULT_TEST_MAX_DURATION = 1 # Timeout tests follow a general pattern: one side waits TIMEOUT seconds for an @@ -81,27 +81,25 @@ @pytest.fixture async def echo_server(nursery): - ''' A server that reads one message, sends back the same message, - then closes the connection. ''' - serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0, - ssl_context=None) + """A server that reads one message, sends back the same message, + then closes the connection.""" + serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0, ssl_context=None) server = await nursery.start(serve_fn) yield server @pytest.fixture async def echo_conn(echo_server): - ''' Return a client connection instance that is connected to an echo - server. ''' - async with open_websocket(HOST, echo_server.port, RESOURCE, - use_ssl=False) as conn: + """Return a client connection instance that is connected to an echo + server.""" + async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) as conn: yield conn async def echo_request_handler(request): - ''' + """ Accept incoming request and then pass off to echo connection handler. - ''' + """ conn = await request.accept() try: msg = await conn.get_message() @@ -111,8 +109,9 @@ async def echo_request_handler(request): class fail_after: - ''' This decorator fails if the runtime of the decorated function (as - measured by the Trio clock) exceeds the specified value. ''' + """This decorator fails if the runtime of the decorated function (as + measured by the Trio clock) exceeds the specified value.""" + def __init__(self, seconds): self._seconds = seconds @@ -122,7 +121,10 @@ async def wrapper(*args, **kwargs): with trio.move_on_after(self._seconds) as cancel_scope: await fn(*args, **kwargs) if cancel_scope.cancelled_caught: - pytest.fail(f'Test runtime exceeded the maximum {self._seconds} seconds') + pytest.fail( + f"Test runtime exceeded the maximum {self._seconds} seconds" + ) + return wrapper @@ -154,41 +156,41 @@ async def aclose(self): async def test_endpoint_ipv4(): - e1 = Endpoint('10.105.0.2', 80, False) - assert e1.url == 'ws://10.105.0.2' + e1 = Endpoint("10.105.0.2", 80, False) + assert e1.url == "ws://10.105.0.2" assert str(e1) == 'Endpoint(address="10.105.0.2", port=80, is_ssl=False)' - e2 = Endpoint('127.0.0.1', 8000, False) - assert e2.url == 'ws://127.0.0.1:8000' + e2 = Endpoint("127.0.0.1", 8000, False) + assert e2.url == "ws://127.0.0.1:8000" assert str(e2) == 'Endpoint(address="127.0.0.1", port=8000, is_ssl=False)' - e3 = Endpoint('0.0.0.0', 443, True) - assert e3.url == 'wss://0.0.0.0' + e3 = Endpoint("0.0.0.0", 443, True) + assert e3.url == "wss://0.0.0.0" assert str(e3) == 'Endpoint(address="0.0.0.0", port=443, is_ssl=True)' async def test_listen_port_ipv6(): - e1 = Endpoint('2599:8807:6201:b7:16cf:bb9c:a6d3:51ab', 80, False) - assert e1.url == 'ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]' - assert str(e1) == 'Endpoint(address="2599:8807:6201:b7:16cf:bb9c:a6d3' \ - ':51ab", port=80, is_ssl=False)' - e2 = Endpoint('::1', 8000, False) - assert e2.url == 'ws://[::1]:8000' + e1 = Endpoint("2599:8807:6201:b7:16cf:bb9c:a6d3:51ab", 80, False) + assert e1.url == "ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]" + assert ( + str(e1) == 'Endpoint(address="2599:8807:6201:b7:16cf:bb9c:a6d3' + ':51ab", port=80, is_ssl=False)' + ) + e2 = Endpoint("::1", 8000, False) + assert e2.url == "ws://[::1]:8000" assert str(e2) == 'Endpoint(address="::1", port=8000, is_ssl=False)' - e3 = Endpoint('::', 443, True) - assert e3.url == 'wss://[::]' + e3 = Endpoint("::", 443, True) + assert e3.url == "wss://[::]" assert str(e3) == 'Endpoint(address="::", port=443, is_ssl=True)' async def test_server_has_listeners(nursery): - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - None) + server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) assert len(server.listeners) > 0 assert isinstance(server.listeners[0], Endpoint) async def test_serve(nursery): task = current_task() - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - None) + server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) port = server.port assert server.port != 0 # The server nursery begins with one task (server.listen). @@ -209,11 +211,11 @@ async def test_serve_ssl(nursery): cert = ca.issue_server_cert(HOST) cert.configure_cert(server_context) - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - server_context) + server = await nursery.start( + serve_websocket, echo_request_handler, HOST, 0, server_context + ) port = server.port - async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context - ) as conn: + async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context) as conn: assert not conn.closed assert conn.local.is_ssl assert conn.remote.is_ssl @@ -222,8 +224,14 @@ async def test_serve_ssl(nursery): async def test_serve_handler_nursery(nursery): task = current_task() async with trio.open_nursery() as handler_nursery: - serve_with_nursery = partial(serve_websocket, echo_request_handler, - HOST, 0, None, handler_nursery=handler_nursery) + serve_with_nursery = partial( + serve_websocket, + echo_request_handler, + HOST, + 0, + None, + handler_nursery=handler_nursery, + ) server = await nursery.start(serve_with_nursery) port = server.port # The server nursery begins with one task (server.listen). @@ -248,7 +256,7 @@ async def test_serve_non_tcp_listener(nursery): assert len(server.listeners) == 1 with pytest.raises(RuntimeError): server.port # pylint: disable=pointless-statement - assert server.listeners[0].startswith('MemoryListener(') + assert server.listeners[0].startswith("MemoryListener(") async def test_serve_multiple_listeners(nursery): @@ -265,74 +273,77 @@ async def test_serve_multiple_listeners(nursery): assert server.listeners[0].port != 0 # The second listener metadata is a string containing the repr() of a # MemoryListener object. - assert server.listeners[1].startswith('MemoryListener(') + assert server.listeners[1].startswith("MemoryListener(") async def test_client_open(echo_server): - async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) \ - as conn: + async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) as conn: assert not conn.closed assert conn.is_client - assert str(conn).startswith('client-') + assert str(conn).startswith("client-") -@pytest.mark.parametrize('path, expected_path', [ - ('/', '/'), - ('', '/'), - (RESOURCE + '/path', RESOURCE + '/path'), - (RESOURCE + '?foo=bar', RESOURCE + '?foo=bar') -]) +@pytest.mark.parametrize( + "path, expected_path", + [ + ("/", "/"), + ("", "/"), + (RESOURCE + "/path", RESOURCE + "/path"), + (RESOURCE + "?foo=bar", RESOURCE + "?foo=bar"), + ], +) async def test_client_open_url(path, expected_path, echo_server): - url = f'ws://{HOST}:{echo_server.port}{path}' + url = f"ws://{HOST}:{echo_server.port}{path}" async with open_websocket_url(url) as conn: assert conn.path == expected_path async def test_client_open_invalid_url(echo_server): with pytest.raises(ValueError): - async with open_websocket_url('http://foo.com/bar') as conn: + async with open_websocket_url("http://foo.com/bar") as conn: pass async def test_ascii_encoded_path_is_ok(echo_server): - path = '%D7%90%D7%91%D7%90?%D7%90%D7%9E%D7%90' - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}/{path}' + path = "%D7%90%D7%91%D7%90?%D7%90%D7%9E%D7%90" + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}/{path}" async with open_websocket_url(url) as conn: - assert conn.path == RESOURCE + '/' + path + assert conn.path == RESOURCE + "/" + path -@patch('trio_websocket._impl.open_websocket') +@patch("trio_websocket._impl.open_websocket") def test_client_open_url_options(open_websocket_mock): """open_websocket_url() must pass its options on to open_websocket()""" port = 1234 - url = f'ws://{HOST}:{port}{RESOURCE}' + url = f"ws://{HOST}:{port}{RESOURCE}" options = { - 'subprotocols': ['chat'], - 'extra_headers': [(b'X-Test-Header', b'My test header')], - 'message_queue_size': 9, - 'max_message_size': 333, - 'connect_timeout': 36, - 'disconnect_timeout': 37, + "subprotocols": ["chat"], + "extra_headers": [(b"X-Test-Header", b"My test header")], + "message_queue_size": 9, + "max_message_size": 333, + "connect_timeout": 36, + "disconnect_timeout": 37, } open_websocket_url(url, **options) _, call_args, call_kwargs = open_websocket_mock.mock_calls[0] assert call_args == (HOST, port, RESOURCE) - assert not call_kwargs.pop('use_ssl') + assert not call_kwargs.pop("use_ssl") assert call_kwargs == options - open_websocket_url(url.replace('ws:', 'wss:')) + open_websocket_url(url.replace("ws:", "wss:")) _, call_args, call_kwargs = open_websocket_mock.mock_calls[1] - assert call_kwargs['use_ssl'] + assert call_kwargs["use_ssl"] async def test_client_connect(echo_server, nursery): - conn = await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, - use_ssl=False) + conn = await connect_websocket( + nursery, HOST, echo_server.port, RESOURCE, use_ssl=False + ) assert not conn.closed async def test_client_connect_url(echo_server, nursery): - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}" conn = await connect_websocket_url(nursery, url) assert not conn.closed @@ -361,21 +372,21 @@ async def handler(request): conn = await request.accept() server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False - ) as client_ws: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False) as client_ws: pass async def test_handshake_subprotocol(nursery): async def handler(request): - assert request.proposed_subprotocols == ('chat', 'file') - server_ws = await request.accept(subprotocol='chat') - assert server_ws.subprotocol == 'chat' + assert request.proposed_subprotocols == ("chat", "file") + server_ws = await request.accept(subprotocol="chat") + assert server_ws.subprotocol == "chat" server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - subprotocols=('chat', 'file')) as client_ws: - assert client_ws.subprotocol == 'chat' + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, subprotocols=("chat", "file") + ) as client_ws: + assert client_ws.subprotocol == "chat" async def test_handshake_path(nursery): @@ -385,8 +396,12 @@ async def handler(request): assert server_ws.path == RESOURCE server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - ) as client_ws: + async with open_websocket( + HOST, + server.port, + RESOURCE, + use_ssl=False, + ) as client_ws: assert client_ws.path == RESOURCE @@ -394,107 +409,118 @@ async def handler(request): async def test_handshake_client_headers(nursery): async def handler(request): headers = dict(request.headers) - assert b'x-test-header' in headers - assert headers[b'x-test-header'] == b'My test header' + assert b"x-test-header" in headers + assert headers[b"x-test-header"] == b"My test header" server_ws = await request.accept() - await server_ws.send_message('test') + await server_ws.send_message("test") server = await nursery.start(serve_websocket, handler, HOST, 0, None) - headers = [(b'X-Test-Header', b'My test header')] - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - extra_headers=headers) as client_ws: + headers = [(b"X-Test-Header", b"My test header")] + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, extra_headers=headers + ) as client_ws: await client_ws.get_message() @fail_after(1) async def test_handshake_server_headers(nursery): async def handler(request): - headers = [('X-Test-Header', 'My test header')] + headers = [("X-Test-Header", "My test header")] server_ws = await request.accept(extra_headers=headers) server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False - ) as client_ws: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False) as client_ws: header_key, header_value = client_ws.handshake_headers[0] - assert header_key == b'x-test-header' - assert header_value == b'My test header' + assert header_key == b"x-test-header" + assert header_value == b"My test header" @fail_after(1) async def test_handshake_exception_before_accept(): - ''' In #107, a request handler that throws an exception before finishing the + """In #107, a request handler that throws an exception before finishing the handshake causes the task to hang. The proper behavior is to raise an - exception to the nursery as soon as possible. ''' + exception to the nursery as soon as possible.""" + async def handler(request): raise ValueError() with pytest.raises(ValueError): async with trio.open_nursery() as nursery: - server = await nursery.start(serve_websocket, handler, HOST, 0, - None) - async with open_websocket(HOST, server.port, RESOURCE, - use_ssl=False) as client_ws: + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False + ) as client_ws: pass @fail_after(1) async def test_reject_handshake(nursery): async def handler(request): - body = b'My body' + body = b"My body" await request.reject(400, body=body) server = await nursery.start(serve_websocket, handler, HOST, 0, None) with pytest.raises(ConnectionRejected) as exc_info: - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - ) as client_ws: + async with open_websocket( + HOST, + server.port, + RESOURCE, + use_ssl=False, + ) as client_ws: pass exc = exc_info.value - assert exc.body == b'My body' + assert exc.body == b"My body" @fail_after(1) async def test_reject_handshake_invalid_info_status(nursery): - ''' + """ An informational status code that is not 101 should cause the client to reject the handshake. Since it is an informational response, there will not be a response body, so this test exercises a different code path. - ''' + """ + async def handler(stream): - await stream.send_all(b'HTTP/1.1 100 CONTINUE\r\n\r\n') + await stream.send_all(b"HTTP/1.1 100 CONTINUE\r\n\r\n") await stream.receive_some(max_bytes=1024) + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) listeners = await nursery.start(serve_fn) port = listeners[0].socket.getsockname()[1] with pytest.raises(ConnectionRejected) as exc_info: - async with open_websocket(HOST, port, RESOURCE, use_ssl=False, - ) as client_ws: + async with open_websocket( + HOST, + port, + RESOURCE, + use_ssl=False, + ) as client_ws: pass exc = exc_info.value assert exc.status_code == 100 - assert repr(exc) == 'ConnectionRejected' + assert repr(exc) == "ConnectionRejected" assert exc.body is None async def test_handshake_protocol_error(nursery, echo_server): - ''' + """ If a client connects to a trio-websocket server and tries to speak HTTP instead of WebSocket, the server should reject the connection. (If the server does not catch the protocol exception, it will raise an exception up to the nursery level and fail the test.) - ''' + """ client_stream = await trio.open_tcp_stream(HOST, echo_server.port) async with client_stream: - await client_stream.send_all(b'GET / HTTP/1.1\r\n\r\n') + await client_stream.send_all(b"GET / HTTP/1.1\r\n\r\n") response = await client_stream.receive_some(1024) - assert response.startswith(b'HTTP/1.1 400') + assert response.startswith(b"HTTP/1.1 400") async def test_client_send_and_receive(echo_conn): async with echo_conn: - await echo_conn.send_message('This is a test message.') + await echo_conn.send_message("This is a test message.") received_msg = await echo_conn.get_message() - assert received_msg == 'This is a test message.' + assert received_msg == "This is a test message." async def test_client_send_invalid_type(echo_conn): @@ -505,17 +531,19 @@ async def test_client_send_invalid_type(echo_conn): async def test_client_ping(echo_conn): async with echo_conn: - await echo_conn.ping(b'A') + await echo_conn.ping(b"A") with pytest.raises(ConnectionClosed): - await echo_conn.ping(b'B') + await echo_conn.ping(b"B") async def test_client_ping_two_payloads(echo_conn): pong_count = 0 + async def ping_and_count(): nonlocal pong_count await echo_conn.ping() pong_count += 1 + async with echo_conn: async with trio.open_nursery() as nursery: nursery.start_soon(ping_and_count) @@ -528,12 +556,14 @@ async def test_client_ping_same_payload(echo_conn): # same time. One of them should succeed and the other should get an # exception. exc_count = 0 + async def ping_and_catch(): nonlocal exc_count try: - await echo_conn.ping(b'A') + await echo_conn.ping(b"A") except ValueError: exc_count += 1 + async with echo_conn: async with trio.open_nursery() as nursery: nursery.start_soon(ping_and_catch) @@ -543,9 +573,9 @@ async def ping_and_catch(): async def test_client_pong(echo_conn): async with echo_conn: - await echo_conn.pong(b'A') + await echo_conn.pong(b"A") with pytest.raises(ConnectionClosed): - await echo_conn.pong(b'B') + await echo_conn.pong(b"B") async def test_client_default_close(echo_conn): @@ -553,16 +583,18 @@ async def test_client_default_close(echo_conn): assert not echo_conn.closed assert echo_conn.closed.code == 1000 assert echo_conn.closed.reason is None - assert repr(echo_conn.closed) == 'CloseReason' + assert ( + repr(echo_conn.closed) == "CloseReason" + ) async def test_client_nondefault_close(echo_conn): async with echo_conn: assert not echo_conn.closed - await echo_conn.aclose(code=1001, reason='test reason') + await echo_conn.aclose(code=1001, reason="test reason") assert echo_conn.closed.code == 1001 - assert echo_conn.closed.reason == 'test reason' + assert echo_conn.closed.reason == "test reason" async def test_wrap_client_stream(nursery): @@ -573,10 +605,10 @@ async def test_wrap_client_stream(nursery): conn = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with conn: assert not conn.closed - await conn.send_message('Hello from client!') + await conn.send_message("Hello from client!") msg = await conn.get_message() - assert msg == 'Hello from client!' - assert conn.local.startswith('StapledStream(') + assert msg == "Hello from client!" + assert conn.local.startswith("StapledStream(") assert conn.closed @@ -587,38 +619,42 @@ async def handler(stream): async with server_ws: assert not server_ws.closed msg = await server_ws.get_message() - assert msg == 'Hello from client!' + assert msg == "Hello from client!" assert server_ws.closed + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) listeners = await nursery.start(serve_fn) port = listeners[0].socket.getsockname()[1] async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: - await client.send_message('Hello from client!') + await client.send_message("Hello from client!") @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_client_open_timeout(nursery, autojump_clock): - ''' + """ The client times out waiting for the server to complete the opening handshake. - ''' + """ + async def handler(request): await trio.sleep(FORCE_TIMEOUT) server_ws = await request.accept() - pytest.fail('Should not reach this line.') + pytest.fail("Should not reach this line.") server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) with pytest.raises(ConnectionTimeout): - async with open_websocket(HOST, server.port, '/', use_ssl=False, - connect_timeout=TIMEOUT) as client_ws: + async with open_websocket( + HOST, server.port, "/", use_ssl=False, connect_timeout=TIMEOUT + ) as client_ws: pass @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_client_close_timeout(nursery, autojump_clock): - ''' + """ This client times out waiting for the server to complete the closing handshake. @@ -626,68 +662,83 @@ async def test_client_close_timeout(nursery, autojump_clock): queue size is 0, and the client sends it exactly 1 message. This blocks the server's reader so it won't do the closing handshake for at least ``FORCE_TIMEOUT`` seconds. - ''' + """ + async def handler(request): server_ws = await request.accept() await trio.sleep(FORCE_TIMEOUT) # The next line should raise ConnectionClosed. await server_ws.get_message() - pytest.fail('Should not reach this line.') + pytest.fail("Should not reach this line.") server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None, - message_queue_size=0)) + partial( + serve_websocket, handler, HOST, 0, ssl_context=None, message_queue_size=0 + ) + ) with pytest.raises(DisconnectionTimeout): - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - disconnect_timeout=TIMEOUT) as client_ws: - await client_ws.send_message('test') + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, disconnect_timeout=TIMEOUT + ) as client_ws: + await client_ws.send_message("test") async def test_client_connect_networking_error(): - with patch('trio_websocket._impl.connect_websocket') as \ - connect_websocket_mock: + with patch("trio_websocket._impl.connect_websocket") as connect_websocket_mock: connect_websocket_mock.side_effect = OSError() with pytest.raises(HandshakeError): - async with open_websocket(HOST, 0, '/', use_ssl=False) as client_ws: + async with open_websocket(HOST, 0, "/", use_ssl=False) as client_ws: pass @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_server_open_timeout(autojump_clock): - ''' + """ The server times out waiting for the client to complete the opening handshake. Server timeouts don't raise exceptions, because handler tasks are launched in an internal nursery and sending exceptions wouldn't be helpful. Instead, timed out tasks silently end. - ''' + """ + async def handler(request): - pytest.fail('This handler should not be called.') + pytest.fail("This handler should not be called.") async with trio.open_nursery() as nursery: - server = await nursery.start(partial(serve_websocket, handler, HOST, 0, - ssl_context=None, handler_nursery=nursery, connect_timeout=TIMEOUT)) + server = await nursery.start( + partial( + serve_websocket, + handler, + HOST, + 0, + ssl_context=None, + handler_nursery=nursery, + connect_timeout=TIMEOUT, + ) + ) old_task_count = len(nursery.child_tasks) # This stream is not a WebSocket, so it won't send a handshake: stream = await trio.open_tcp_stream(HOST, server.port) # Checkpoint so the server's handler task can spawn: await trio.sleep(0) - assert len(nursery.child_tasks) == old_task_count + 1, \ - "Server's reader task did not spawn" + assert ( + len(nursery.child_tasks) == old_task_count + 1 + ), "Server's reader task did not spawn" # Sleep long enough to trigger server's connect_timeout: await trio.sleep(FORCE_TIMEOUT) - assert len(nursery.child_tasks) == old_task_count, \ - "Server's reader task is still running" + assert ( + len(nursery.child_tasks) == old_task_count + ), "Server's reader task is still running" # Cancel the server task: nursery.cancel_scope.cancel() @fail_after(TIMEOUT_TEST_MAX_DURATION) async def test_server_close_timeout(autojump_clock): - ''' + """ The server times out waiting for the client to complete the closing handshake. @@ -698,33 +749,45 @@ async def test_server_close_timeout(autojump_clock): To prevent the client from doing the closing handshake, we make sure that its message queue size is 0 and the server sends it exactly 1 message. This blocks the client's reader and prevents it from doing the client handshake. - ''' + """ + async def handler(request): ws = await request.accept() # Send one message to block the client's reader task: - await ws.send_message('test') + await ws.send_message("test") async with trio.open_nursery() as outer: - server = await outer.start(partial(serve_websocket, handler, HOST, 0, - ssl_context=None, handler_nursery=outer, - disconnect_timeout=TIMEOUT)) + server = await outer.start( + partial( + serve_websocket, + handler, + HOST, + 0, + ssl_context=None, + handler_nursery=outer, + disconnect_timeout=TIMEOUT, + ) + ) old_task_count = len(outer.child_tasks) # Spawn client inside an inner nursery so that we can cancel it's reader # so that it won't do a closing handshake. async with trio.open_nursery() as inner: - ws = await connect_websocket(inner, HOST, server.port, RESOURCE, - use_ssl=False) + ws = await connect_websocket( + inner, HOST, server.port, RESOURCE, use_ssl=False + ) # Checkpoint so the server can spawn a handler task: await trio.sleep(0) - assert len(outer.child_tasks) == old_task_count + 1, \ - "Server's reader task did not spawn" + assert ( + len(outer.child_tasks) == old_task_count + 1 + ), "Server's reader task did not spawn" # The client waits long enough to trigger the server's disconnect # timeout: await trio.sleep(FORCE_TIMEOUT) # The server should have cancelled the handler: - assert len(outer.child_tasks) == old_task_count, \ - "Server's reader task is still running" + assert ( + len(outer.child_tasks) == old_task_count + ), "Server's reader task is still running" # Cancel the client's reader task: inner.cancel_scope.cancel() @@ -737,13 +800,14 @@ async def handler(request): server_ws = await request.accept() with pytest.raises(ConnectionClosed): await server_ws.get_message() + server = await nursery.start(serve_websocket, handler, HOST, 0, None) stream = await trio.open_tcp_stream(HOST, server.port) client_ws = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with client_ws: await stream.aclose() with pytest.raises(ConnectionClosed): - await client_ws.send_message('Hello from client!') + await client_ws.send_message("Hello from client!") async def test_server_sends_after_close(nursery): @@ -753,7 +817,7 @@ async def handler(request): server_ws = await request.accept() with pytest.raises(ConnectionClosed): while True: - await server_ws.send_message('Hello from server') + await server_ws.send_message("Hello from server") done.set() server = await nursery.start(serve_websocket, handler, HOST, 0, None) @@ -762,7 +826,7 @@ async def handler(request): async with client_ws: # pump a few messages for x in range(2): - await client_ws.send_message('Hello from client') + await client_ws.send_message("Hello from client") await stream.aclose() await done.wait() @@ -774,7 +838,8 @@ async def handler(stream): async with server_ws: await stream.aclose() with pytest.raises(ConnectionClosed): - await server_ws.send_message('Hello from client!') + await server_ws.send_message("Hello from client!") + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) listeners = await nursery.start(serve_fn) port = listeners[0].socket.getsockname()[1] @@ -789,69 +854,72 @@ async def handler(request): await trio.sleep(1) server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) # connection should close when server handler exits with trio.fail_after(2): - async with open_websocket( - HOST, server.port, '/', use_ssl=False) as connection: + async with open_websocket(HOST, server.port, "/", use_ssl=False) as connection: with pytest.raises(ConnectionClosed) as exc_info: await connection.get_message() exc = exc_info.value - assert exc.reason.name == 'NORMAL_CLOSURE' + assert exc.reason.name == "NORMAL_CLOSURE" @fail_after(DEFAULT_TEST_MAX_DURATION) async def test_read_messages_after_remote_close(nursery): - ''' + """ When the remote endpoint closes, the local endpoint can still read all of the messages sent prior to closing. Any attempt to read beyond that will raise ConnectionClosed. This test also exercises the configuration of the queue size. - ''' + """ server_closed = trio.Event() async def handler(request): server = await request.accept() async with server: - await server.send_message('1') - await server.send_message('2') + await server.send_message("1") + await server.send_message("2") server_closed.set() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) # The client needs a message queue of size 2 so that it can buffer both # incoming messages without blocking the reader task. - async with open_websocket(HOST, server.port, '/', use_ssl=False, - message_queue_size=2) as client: + async with open_websocket( + HOST, server.port, "/", use_ssl=False, message_queue_size=2 + ) as client: await server_closed.wait() - assert await client.get_message() == '1' - assert await client.get_message() == '2' + assert await client.get_message() == "1" + assert await client.get_message() == "2" with pytest.raises(ConnectionClosed): await client.get_message() async def test_no_messages_after_local_close(nursery): - ''' + """ If the local endpoint initiates closing, then pending messages are discarded and any attempt to read a message will raise ConnectionClosed. - ''' + """ client_closed = trio.Event() async def handler(request): # The server sends some messages and then closes. server = await request.accept() async with server: - await server.send_message('1') - await server.send_message('2') + await server.send_message("1") + await server.send_message("2") await client_closed.wait() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - async with open_websocket(HOST, server.port, '/', use_ssl=False) as client: + async with open_websocket(HOST, server.port, "/", use_ssl=False) as client: pass with pytest.raises(ConnectionClosed): await client.get_message() @@ -859,28 +927,30 @@ async def handler(request): async def test_cm_exit_with_pending_messages(echo_server, autojump_clock): - ''' + """ Regression test for #74, where a context manager was not able to exit when there were pending messages in the receive queue. - ''' + """ with trio.fail_after(1): - async with open_websocket(HOST, echo_server.port, RESOURCE, - use_ssl=False) as ws: - await ws.send_message('hello') + async with open_websocket( + HOST, echo_server.port, RESOURCE, use_ssl=False + ) as ws: + await ws.send_message("hello") # allow time for the server to respond - await trio.sleep(.1) + await trio.sleep(0.1) @fail_after(DEFAULT_TEST_MAX_DURATION) async def test_max_message_size(nursery): - ''' + """ Set the client's max message size to 100 bytes. The client can send a message larger than 100 bytes, but when it receives a message larger than 100 bytes, it closes the connection with code 1009. - ''' + """ + async def handler(request): - ''' Similar to the echo_request_handler fixture except it runs in a - loop. ''' + """Similar to the echo_request_handler fixture except it runs in a + loop.""" conn = await request.accept() while True: try: @@ -890,16 +960,18 @@ async def handler(request): break server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - max_message_size=100) as client: + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, max_message_size=100 + ) as client: # We can send and receive 100 bytes: - await client.send_message(b'A' * 100) + await client.send_message(b"A" * 100) msg = await client.get_message() assert len(msg) == 100 # We can send 101 bytes but cannot receive 101 bytes: - await client.send_message(b'B' * 101) + await client.send_message(b"B" * 101) with pytest.raises(ConnectionClosed): await client.get_message() assert client.closed @@ -912,19 +984,21 @@ async def test_server_close_client_disconnect_race(nursery, autojump_clock): async def handler(request: WebSocketRequest): ws = await request.accept() ws._for_testing_peer_closed_connection = trio.Event() - await ws.send_message('foo') + await ws.send_message("foo") await ws._for_testing_peer_closed_connection.wait() # with bug, this would raise ConnectionClosed from websocket internal task await trio.aclose_forcefully(ws._stream) server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - connection = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + connection = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) await connection.get_message() await connection.aclose() - await trio.sleep(.1) + await trio.sleep(0.1) async def test_remote_close_local_message_race(nursery, autojump_clock): @@ -944,15 +1018,17 @@ async def handler(request: WebSocketRequest): await ws.aclose() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) client._for_testing_peer_closed_connection = trio.Event() - await client.send_message('foo') + await client.send_message("foo") await client._for_testing_peer_closed_connection.wait() with pytest.raises(ConnectionClosed): - await client.send_message('bar') + await client.send_message("bar") async def test_message_after_local_close_race(nursery): @@ -963,10 +1039,12 @@ async def handler(request: WebSocketRequest): await trio.sleep_forever() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) orig_send = client._send close_sent = trio.Event() @@ -981,7 +1059,7 @@ async def _send_wrapper(event): await close_sent.wait() assert client.closed with pytest.raises(ConnectionClosed): - await client.send_message('hello') + await client.send_message("hello") @fail_after(DEFAULT_TEST_MAX_DURATION) @@ -999,9 +1077,11 @@ async def handle_connection(request): await trio.sleep_forever() server = await nursery.start( - partial(serve_websocket, handle_connection, HOST, 0, ssl_context=None)) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + partial(serve_websocket, handle_connection, HOST, 0, ssl_context=None) + ) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) # send a CloseConnection event to server but leave client connected await client._send(CloseConnection(code=1000)) await server_stream_closed.wait() diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 9c8c90e..e774229 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -30,13 +30,13 @@ ) import wsproto.utilities -_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.'))) < (0, 22, 0) +_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split("."))) < (0, 22, 0) -CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds +CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds MESSAGE_QUEUE_SIZE = 1 -MAX_MESSAGE_SIZE = 2 ** 20 # 1 MiB -RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB -logger = logging.getLogger('trio-websocket') +MAX_MESSAGE_SIZE = 2**20 # 1 MiB +RECEIVE_BYTES = 4 * 2**10 # 4 KiB +logger = logging.getLogger("trio-websocket") def _ignore_cancel(exc): @@ -53,6 +53,7 @@ class _preserve_current_exception: https://github.com/python-trio/trio/issues/1559 https://gitter.im/python-trio/general?at=5faf2293d37a1a13d6a582cf """ + __slots__ = ("_armed",) def __init__(self): @@ -66,20 +67,33 @@ def __exit__(self, ty, value, tb): return False if _TRIO_MULTI_ERROR: - filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member + filtered_exception = trio.MultiError.filter( + _ignore_cancel, value + ) # pylint: disable=no-member elif isinstance(value, BaseExceptionGroup): - filtered_exception = value.subgroup(lambda exc: not isinstance(exc, trio.Cancelled)) + filtered_exception = value.subgroup( + lambda exc: not isinstance(exc, trio.Cancelled) + ) else: filtered_exception = _ignore_cancel(value) return filtered_exception is None @asynccontextmanager -async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, +async def open_websocket( + host, + port, + resource, + *, + use_ssl, + subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, - connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): - ''' + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, + connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, +): + """ Open a WebSocket client connection to a host. This async context manager connects when entering the context manager and @@ -110,15 +124,21 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, :raises HandshakeError: for any networking error, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. - ''' + """ async with trio.open_nursery() as new_nursery: try: with trio.fail_after(connect_timeout): - connection = await connect_websocket(new_nursery, host, port, - resource, use_ssl=use_ssl, subprotocols=subprotocols, + connection = await connect_websocket( + new_nursery, + host, + port, + resource, + use_ssl=use_ssl, + subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) except trio.TooSlowError: raise ConnectionTimeout from None except OSError as e: @@ -133,10 +153,19 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, raise DisconnectionTimeout from None -async def connect_websocket(nursery, host, port, resource, *, use_ssl, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def connect_websocket( + nursery, + host, + port, + resource, + *, + use_ssl, + subprotocols=None, + extra_headers=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +): + """ Return an open WebSocket client connection to a host. This function is used to specify a custom nursery to run connection @@ -164,7 +193,7 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' + """ if use_ssl is True: ssl_context = ssl.create_default_context() elif use_ssl is False: @@ -172,36 +201,52 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, elif isinstance(use_ssl, ssl.SSLContext): ssl_context = use_ssl else: - raise TypeError('`use_ssl` argument must be bool or ssl.SSLContext') - - logger.debug('Connecting to ws%s://%s:%d%s', - '' if ssl_context is None else 's', host, port, resource) + raise TypeError("`use_ssl` argument must be bool or ssl.SSLContext") + + logger.debug( + "Connecting to ws%s://%s:%d%s", + "" if ssl_context is None else "s", + host, + port, + resource, + ) if ssl_context is None: stream = await trio.open_tcp_stream(host, port) else: - stream = await trio.open_ssl_over_tcp_stream(host, port, - ssl_context=ssl_context, https_compatible=True) + stream = await trio.open_ssl_over_tcp_stream( + host, port, ssl_context=ssl_context, https_compatible=True + ) if port in (80, 443): host_header = host else: - host_header = f'{host}:{port}' - connection = WebSocketConnection(stream, + host_header = f"{host}:{port}" + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.CLIENT), host=host_header, path=resource, - client_subprotocols=subprotocols, client_extra_headers=extra_headers, + client_subprotocols=subprotocols, + client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection -def open_websocket_url(url, ssl_context=None, *, subprotocols=None, +def open_websocket_url( + url, + ssl_context=None, + *, + subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, - connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): - ''' + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, + connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, +): + """ Open a WebSocket client connection to a URL. This async context manager connects when entering the context manager and @@ -230,19 +275,33 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None, :raises HandshakeError: for any networking error, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. - ''' + """ host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return open_websocket(host, port, resource, use_ssl=ssl_context, - subprotocols=subprotocols, extra_headers=extra_headers, + return open_websocket( + host, + port, + resource, + use_ssl=ssl_context, + subprotocols=subprotocols, + extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size, - connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout) + connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout, + ) -async def connect_websocket_url(nursery, url, ssl_context=None, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def connect_websocket_url( + nursery, + url, + ssl_context=None, + *, + subprotocols=None, + extra_headers=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +): + """ Return an open WebSocket client connection to a URL. This function is used to specify a custom nursery to run connection @@ -267,16 +326,23 @@ async def connect_websocket_url(nursery, url, ssl_context=None, *, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' + """ host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return await connect_websocket(nursery, host, port, resource, - use_ssl=ssl_context, subprotocols=subprotocols, - extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + return await connect_websocket( + nursery, + host, + port, + resource, + use_ssl=ssl_context, + subprotocols=subprotocols, + extra_headers=extra_headers, + message_queue_size=message_queue_size, + max_message_size=max_message_size, + ) def _url_to_host(url, ssl_context): - ''' + """ Convert a WebSocket URL to a (host,port,resource) tuple. The returned ``ssl_context`` is either the same object that was passed in, @@ -286,15 +352,15 @@ def _url_to_host(url, ssl_context): :param str url: A WebSocket URL. :type ssl_context: ssl.SSLContext or None :returns: A tuple of ``(host, port, resource, ssl_context)``. - ''' + """ url = str(url) # For backward compat with isinstance(url, yarl.URL). parts = urllib.parse.urlsplit(url) - if parts.scheme not in ('ws', 'wss'): + if parts.scheme not in ("ws", "wss"): raise ValueError('WebSocket URL scheme must be "ws:" or "wss:"') if ssl_context is None: - ssl_context = parts.scheme == 'wss' - elif parts.scheme == 'ws': - raise ValueError('SSL context must be None for ws: URL scheme') + ssl_context = parts.scheme == "wss" + elif parts.scheme == "ws": + raise ValueError("SSL context must be None for ws: URL scheme") host = parts.hostname if parts.port is not None: port = parts.port @@ -305,16 +371,24 @@ def _url_to_host(url, ssl_context): # If the target URI's path component is empty, the client MUST # send "/" as the path within the origin-form of request-target. if not path_qs: - path_qs = '/' - if '?' in url: - path_qs += '?' + parts.query + path_qs = "/" + if "?" in url: + path_qs += "?" + parts.query return host, port, path_qs, ssl_context -async def wrap_client_stream(nursery, stream, host, resource, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def wrap_client_stream( + nursery, + stream, + host, + resource, + *, + subprotocols=None, + extra_headers=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +): + """ Wrap an arbitrary stream in a WebSocket connection. This is a low-level function only needed in rare cases. In most cases, you @@ -338,21 +412,29 @@ async def wrap_client_stream(nursery, stream, host, resource, *, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' - connection = WebSocketConnection(stream, + """ + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.CLIENT), - host=host, path=resource, - client_subprotocols=subprotocols, client_extra_headers=extra_headers, + host=host, + path=resource, + client_subprotocols=subprotocols, + client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection -async def wrap_server_stream(nursery, stream, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def wrap_server_stream( + nursery, + stream, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, +): + """ Wrap an arbitrary stream in a server-side WebSocket. This is a low-level function only needed in rare cases. In most cases, you @@ -367,21 +449,32 @@ async def wrap_server_stream(nursery, stream, then the connection is closed with code 1009 (Message Too Big). :type stream: trio.abc.Stream :rtype: WebSocketRequest - ''' - connection = WebSocketConnection(stream, + """ + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.SERVER), message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) request = await connection._get_request() return request -async def serve_websocket(handler, host, port, ssl_context, *, - handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED): - ''' +async def serve_websocket( + handler, + host, + port, + ssl_context, + *, + handler_nursery=None, + message_queue_size=MESSAGE_QUEUE_SIZE, + max_message_size=MAX_MESSAGE_SIZE, + connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, + task_status=trio.TASK_STATUS_IGNORED, +): + """ Serve a WebSocket over TCP. This function supports the Trio nursery start protocol: ``server = await @@ -415,64 +508,79 @@ async def serve_websocket(handler, host, port, ssl_context, *, to finish the closing handshake before timing out. :param task_status: Part of Trio nursery start protocol. :returns: This function runs until cancelled. - ''' + """ if ssl_context is None: open_tcp_listeners = partial(trio.open_tcp_listeners, port, host=host) else: - open_tcp_listeners = partial(trio.open_ssl_over_tcp_listeners, port, - ssl_context, host=host, https_compatible=True) + open_tcp_listeners = partial( + trio.open_ssl_over_tcp_listeners, + port, + ssl_context, + host=host, + https_compatible=True, + ) listeners = await open_tcp_listeners() - server = WebSocketServer(handler, listeners, - handler_nursery=handler_nursery, message_queue_size=message_queue_size, - max_message_size=max_message_size, connect_timeout=connect_timeout, - disconnect_timeout=disconnect_timeout) + server = WebSocketServer( + handler, + listeners, + handler_nursery=handler_nursery, + message_queue_size=message_queue_size, + max_message_size=max_message_size, + connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout, + ) await server.run(task_status=task_status) class HandshakeError(Exception): - ''' + """ There was an error during connection or disconnection with the websocket server. - ''' + """ + class ConnectionTimeout(HandshakeError): - '''There was a timeout when connecting to the websocket server.''' + """There was a timeout when connecting to the websocket server.""" + class DisconnectionTimeout(HandshakeError): - '''There was a timeout when disconnecting from the websocket server.''' + """There was a timeout when disconnecting from the websocket server.""" + class ConnectionClosed(Exception): - ''' + """ A WebSocket operation cannot be completed because the connection is closed or in the process of closing. - ''' + """ + def __init__(self, reason): - ''' + """ Constructor. :param reason: :type reason: CloseReason - ''' + """ super().__init__() self.reason = reason def __repr__(self): - ''' Return representation. ''' - return f'{self.__class__.__name__}<{self.reason}>' + """Return representation.""" + return f"{self.__class__.__name__}<{self.reason}>" class ConnectionRejected(HandshakeError): - ''' + """ A WebSocket connection could not be established because the server rejected the connection attempt. - ''' + """ + def __init__(self, status_code, headers, body): - ''' + """ Constructor. :param reason: :type reason: CloseReason - ''' + """ super().__init__() #: a 3 digit HTTP status code self.status_code = status_code @@ -482,144 +590,149 @@ def __init__(self, status_code, headers, body): self.body = body def __repr__(self): - ''' Return representation. ''' - return f'{self.__class__.__name__}' + """Return representation.""" + return f"{self.__class__.__name__}" class CloseReason: - ''' Contains information about why a WebSocket was closed. ''' + """Contains information about why a WebSocket was closed.""" + def __init__(self, code, reason): - ''' + """ Constructor. :param int code: :param Optional[str] reason: - ''' + """ self._code = code try: self._name = wsframeproto.CloseReason(code).name except ValueError: if 1000 <= code <= 2999: - self._name = 'RFC_RESERVED' + self._name = "RFC_RESERVED" elif 3000 <= code <= 3999: - self._name = 'IANA_RESERVED' + self._name = "IANA_RESERVED" elif 4000 <= code <= 4999: - self._name = 'PRIVATE_RESERVED' + self._name = "PRIVATE_RESERVED" else: - self._name = 'INVALID_CODE' + self._name = "INVALID_CODE" self._reason = reason @property def code(self): - ''' (Read-only) The numeric close code. ''' + """(Read-only) The numeric close code.""" return self._code @property def name(self): - ''' (Read-only) The human-readable close code. ''' + """(Read-only) The human-readable close code.""" return self._name @property def reason(self): - ''' (Read-only) An arbitrary reason string. ''' + """(Read-only) An arbitrary reason string.""" return self._reason def __repr__(self): - ''' Show close code, name, and reason. ''' - return f'{self.__class__.__name__}' \ - f'' + """Show close code, name, and reason.""" + return ( + f"{self.__class__.__name__}" + f"" + ) class Future: - ''' Represents a value that will be available in the future. ''' + """Represents a value that will be available in the future.""" + def __init__(self): - ''' Constructor. ''' + """Constructor.""" self._value = None self._value_event = trio.Event() def set_value(self, value): - ''' + """ Set a value, which will notify any waiters. :param value: - ''' + """ self._value = value self._value_event.set() async def wait_value(self): - ''' + """ Wait for this future to have a value, then return it. :returns: The value set by ``set_value()``. - ''' + """ await self._value_event.wait() return self._value class WebSocketRequest: - ''' + """ Represents a handshake presented by a client to a server. The server may modify the handshake or leave it as is. The server should call ``accept()`` to finish the handshake and obtain a connection object. - ''' + """ + def __init__(self, connection, event): - ''' + """ Constructor. :param WebSocketConnection connection: :type event: wsproto.events.Request - ''' + """ self._connection = connection self._event = event @property def headers(self): - ''' + """ HTTP headers represented as a list of (name, value) pairs. :rtype: list[tuple] - ''' + """ return self._event.extra_headers @property def path(self): - ''' + """ The requested URL path. :rtype: str - ''' + """ return self._event.target @property def proposed_subprotocols(self): - ''' + """ A tuple of protocols proposed by the client. :rtype: tuple[str] - ''' + """ return tuple(self._event.subprotocols) @property def local(self): - ''' + """ The connection's local endpoint. :rtype: Endpoint or str - ''' + """ return self._connection.local @property def remote(self): - ''' + """ The connection's remote endpoint. :rtype: Endpoint or str - ''' + """ return self._connection.remote async def accept(self, *, subprotocol=None, extra_headers=None): - ''' + """ Accept the request and return a connection object. :param subprotocol: The selected subprotocol for this connection. @@ -628,14 +741,14 @@ async def accept(self, *, subprotocol=None, extra_headers=None): send as HTTP headers. :type extra_headers: list[tuple[bytes,bytes]] or None :rtype: WebSocketConnection - ''' + """ if extra_headers is None: extra_headers = [] await self._connection._accept(self._event, subprotocol, extra_headers) return self._connection async def reject(self, status_code, *, extra_headers=None, body=None): - ''' + """ Reject the handshake. :param int status_code: The 3 digit HTTP status code. In order to be @@ -646,14 +759,14 @@ async def reject(self, status_code, *, extra_headers=None, body=None): :param body: If provided, this data will be sent in the response body, otherwise no response body will be sent. :type body: bytes or None - ''' + """ extra_headers = extra_headers or [] - body = body or b'' + body = body or b"" await self._connection._reject(status_code, extra_headers, body) def _get_stream_endpoint(stream, *, local): - ''' + """ Construct an endpoint from a stream. :param trio.Stream stream: @@ -661,7 +774,7 @@ def _get_stream_endpoint(stream, *, local): :returns: An endpoint instance or ``repr()`` for streams that cannot be represented as an endpoint. :rtype: Endpoint or str - ''' + """ socket, is_ssl = None, False if isinstance(stream, trio.SocketStream): socket = stream.socket @@ -677,15 +790,23 @@ def _get_stream_endpoint(stream, *, local): class WebSocketConnection(trio.abc.AsyncResource): - ''' A WebSocket connection. ''' + """A WebSocket connection.""" CONNECTION_ID = itertools.count() - def __init__(self, stream, ws_connection, *, host=None, path=None, - client_subprotocols=None, client_extra_headers=None, + def __init__( + self, + stream, + ws_connection, + *, + host=None, + path=None, + client_subprotocols=None, + client_extra_headers=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE): - ''' + max_message_size=MAX_MESSAGE_SIZE, + ): + """ Constructor. Generally speaking, users are discouraged from directly instantiating a @@ -710,7 +831,7 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). - ''' + """ # NOTE: The implementation uses _close_reason for more than an advisory # purpose. It's critical internal state, indicating when the # connection is closed or closing. @@ -724,9 +845,12 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, self._max_message_size = max_message_size self._reader_running = True if ws_connection.client: - self._initial_request: Optional[Request] = Request(host=host, target=path, + self._initial_request: Optional[Request] = Request( + host=host, + target=path, subprotocols=client_subprotocols, - extra_headers=client_extra_headers or []) + extra_headers=client_extra_headers or [], + ) else: self._initial_request = None self._path = path @@ -734,9 +858,10 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, self._handshake_headers = tuple() self._reject_status = 0 self._reject_headers = tuple() - self._reject_body = b'' + self._reject_body = b"" self._send_channel, self._recv_channel = trio.open_memory_channel( - message_queue_size) + message_queue_size + ) self._pings = OrderedDict() # Set when the server has received a connection request event. This # future is never set on client connections. @@ -753,77 +878,77 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, @property def closed(self): - ''' + """ (Read-only) The reason why the connection was or is being closed, else ``None``. :rtype: Optional[CloseReason] - ''' + """ return self._close_reason @property def is_client(self): - ''' (Read-only) Is this a client instance? ''' + """(Read-only) Is this a client instance?""" return self._wsproto.client @property def is_server(self): - ''' (Read-only) Is this a server instance? ''' + """(Read-only) Is this a server instance?""" return not self._wsproto.client @property def local(self): - ''' + """ The local endpoint of the connection. :rtype: Endpoint or str - ''' + """ return _get_stream_endpoint(self._stream, local=True) @property def remote(self): - ''' + """ The remote endpoint of the connection. :rtype: Endpoint or str - ''' + """ return _get_stream_endpoint(self._stream, local=False) @property def path(self): - ''' + """ The requested URL path. For clients, this is set when the connection is instantiated. For servers, it is set after the handshake completes. :rtype: str - ''' + """ return self._path @property def subprotocol(self): - ''' + """ (Read-only) The negotiated subprotocol, or ``None`` if there is no subprotocol. This is only valid after the opening handshake is complete. :rtype: str or None - ''' + """ return self._subprotocol @property def handshake_headers(self): - ''' + """ The HTTP headers that were sent by the remote during the handshake, stored as 2-tuples containing key/value pairs. Header keys are always lower case. :rtype: tuple[tuple[str,str]] - ''' + """ return self._handshake_headers async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-differ - ''' + """ Close the WebSocket connection. This sends a closing frame and suspends until the connection is closed. @@ -836,7 +961,7 @@ async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-dif :param int code: A 4-digit code number indicating the type of closure. :param str reason: An optional string describing the closure. - ''' + """ with _preserve_current_exception(): await self._aclose(code, reason) @@ -851,8 +976,10 @@ async def _aclose(self, code, reason): # event to peer, while setting the local close reason to normal. self._close_reason = CloseReason(1000, None) await self._send(CloseConnection(code=code, reason=reason)) - elif self._wsproto.state in (ConnectionState.CONNECTING, - ConnectionState.REJECTING): + elif self._wsproto.state in ( + ConnectionState.CONNECTING, + ConnectionState.REJECTING, + ): self._close_handshake.set() # TODO: shouldn't the receive channel be closed earlier, so that # get_message() during send of the CloseConneciton event fails? @@ -867,7 +994,7 @@ async def _aclose(self, code, reason): await self._close_stream() async def get_message(self): - ''' + """ Receive the next WebSocket message. If no message is available immediately, then this function blocks until @@ -882,7 +1009,7 @@ async def get_message(self): :rtype: str or bytes :raises ConnectionClosed: if the connection is closed. - ''' + """ try: message = await self._recv_channel.receive() except (trio.ClosedResourceError, trio.EndOfChannel): @@ -890,7 +1017,7 @@ async def get_message(self): return message async def ping(self, payload=None): - ''' + """ Send WebSocket ping to remote endpoint and wait for a correspoding pong. Each in-flight ping must include a unique payload. This function sends @@ -908,39 +1035,39 @@ async def ping(self, payload=None): :raises ConnectionClosed: if connection is closed. :raises ValueError: if ``payload`` is identical to another in-flight ping. - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) if payload in self._pings: - raise ValueError(f'Payload value {payload} is already in flight.') + raise ValueError(f"Payload value {payload} is already in flight.") if payload is None: - payload = struct.pack('!I', random.getrandbits(32)) + payload = struct.pack("!I", random.getrandbits(32)) event = trio.Event() self._pings[payload] = event await self._send(Ping(payload=payload)) await event.wait() async def pong(self, payload=None): - ''' + """ Send an unsolicted pong. :param payload: The pong's payload. If ``None``, then no payload is sent. :type payload: bytes or None :raises ConnectionClosed: if connection is closed - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) await self._send(Pong(payload=payload)) async def send_message(self, message): - ''' + """ Send a WebSocket message. :param message: The message to send. :type message: str or bytes :raises ConnectionClosed: if connection is closed, or being closed - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) if isinstance(message, str): @@ -948,16 +1075,16 @@ async def send_message(self, message): elif isinstance(message, bytes): event = BytesMessage(data=message) else: - raise ValueError('message must be str or bytes') + raise ValueError("message must be str or bytes") await self._send(event) def __str__(self): - ''' Connection ID and type. ''' - type_ = 'client' if self.is_client else 'server' - return f'{type_}-{self._id}' + """Connection ID and type.""" + type_ = "client" if self.is_client else "server" + return f"{type_}-{self._id}" async def _accept(self, request, subprotocol, extra_headers): - ''' + """ Accept the handshake. This method is only applicable to server-side connections. @@ -967,15 +1094,16 @@ async def _accept(self, request, subprotocol, extra_headers): :type subprotocol: str or None :param list[tuple[bytes,bytes]] extra_headers: A list of 2-tuples containing key/value pairs to send as HTTP headers. - ''' + """ self._subprotocol = subprotocol self._path = request.target - await self._send(AcceptConnection(subprotocol=self._subprotocol, - extra_headers=extra_headers)) + await self._send( + AcceptConnection(subprotocol=self._subprotocol, extra_headers=extra_headers) + ) self._open_handshake.set() async def _reject(self, status_code, headers, body): - ''' + """ Reject the handshake. :param int status_code: The 3 digit HTTP status code. In order to be @@ -984,25 +1112,26 @@ async def _reject(self, status_code, headers, body): :param list[tuple[bytes,bytes]] headers: A list of 2-tuples containing key/value pairs to send as HTTP headers. :param bytes body: An optional response body. - ''' + """ if body: - headers.append(('Content-length', str(len(body)).encode('ascii'))) - reject_conn = RejectConnection(status_code=status_code, headers=headers, - has_body=bool(body)) + headers.append(("Content-length", str(len(body)).encode("ascii"))) + reject_conn = RejectConnection( + status_code=status_code, headers=headers, has_body=bool(body) + ) await self._send(reject_conn) if body: reject_body = RejectData(data=body) await self._send(reject_body) - self._close_reason = CloseReason(1006, 'Rejected WebSocket handshake') + self._close_reason = CloseReason(1006, "Rejected WebSocket handshake") self._close_handshake.set() async def _abort_web_socket(self): - ''' + """ If a stream is closed outside of this class, e.g. due to network conditions or because some other code closed our stream object, then we cannot perform the close handshake. We just need to clean up internal state. - ''' + """ close_reason = wsframeproto.CloseReason.ABNORMAL_CLOSURE if self._wsproto.state == ConnectionState.OPEN: self._wsproto.send(CloseConnection(code=close_reason.value)) @@ -1014,7 +1143,7 @@ async def _abort_web_socket(self): self._close_handshake.set() async def _close_stream(self): - ''' Close the TCP connection. ''' + """Close the TCP connection.""" self._reader_running = False try: with _preserve_current_exception(): @@ -1024,85 +1153,89 @@ async def _close_stream(self): pass async def _close_web_socket(self, code, reason=None): - ''' + """ Mark the WebSocket as closed. Close the message channel so that if any tasks are suspended in get_message(), they will wake up with a ConnectionClosed exception. - ''' + """ self._close_reason = CloseReason(code, reason) exc = ConnectionClosed(self._close_reason) - logger.debug('%s websocket closed %r', self, exc) + logger.debug("%s websocket closed %r", self, exc) await self._send_channel.aclose() async def _get_request(self): - ''' + """ Return a proposal for a WebSocket handshake. This method can only be called on server connections and it may only be called one time. :rtype: WebSocketRequest - ''' + """ if not self.is_server: - raise RuntimeError('This method is only valid for server connections.') + raise RuntimeError("This method is only valid for server connections.") if self._connection_proposal is None: - raise RuntimeError('No proposal available. Did you call this method' - ' multiple times or at the wrong time?') + raise RuntimeError( + "No proposal available. Did you call this method" + " multiple times or at the wrong time?" + ) proposal = await self._connection_proposal.wait_value() self._connection_proposal = None return proposal async def _handle_request_event(self, event): - ''' + """ Handle a connection request. This method is async even though it never awaits, because the event dispatch requires an async function. :param event: - ''' + """ proposal = WebSocketRequest(self, event) self._connection_proposal.set_value(proposal) async def _handle_accept_connection_event(self, event): - ''' + """ Handle an AcceptConnection event. :param wsproto.eventsAcceptConnection event: - ''' + """ self._subprotocol = event.subprotocol self._handshake_headers = tuple(event.extra_headers) self._open_handshake.set() async def _handle_reject_connection_event(self, event): - ''' + """ Handle a RejectConnection event. :param event: - ''' + """ self._reject_status = event.status_code self._reject_headers = tuple(event.headers) if not event.has_body: - raise ConnectionRejected(self._reject_status, self._reject_headers, - body=None) + raise ConnectionRejected( + self._reject_status, self._reject_headers, body=None + ) async def _handle_reject_data_event(self, event): - ''' + """ Handle a RejectData event. :param event: - ''' + """ self._reject_body += event.data if event.body_finished: - raise ConnectionRejected(self._reject_status, self._reject_headers, - body=self._reject_body) + raise ConnectionRejected( + self._reject_status, self._reject_headers, body=self._reject_body + ) async def _handle_close_connection_event(self, event): - ''' + """ Handle a close event. :param wsproto.events.CloseConnection event: - ''' + """ if self._wsproto.state == ConnectionState.REMOTE_CLOSING: # Set _close_reason in advance, so that send_message() will raise # ConnectionClosed during the close handshake. @@ -1119,16 +1252,16 @@ async def _handle_close_connection_event(self, event): await self._close_stream() async def _handle_message_event(self, event): - ''' + """ Handle a message event. :param event: :type event: wsproto.events.BytesMessage or wsproto.events.TextMessage - ''' + """ self._message_size += len(event.data) self._message_parts.append(event.data) if self._message_size > self._max_message_size: - err = f'Exceeded maximum message size: {self._max_message_size} bytes' + err = f"Exceeded maximum message size: {self._max_message_size} bytes" self._message_size = 0 self._message_parts = [] self._close_reason = CloseReason(1009, err) @@ -1136,8 +1269,9 @@ async def _handle_message_event(self, event): await self._recv_channel.aclose() self._reader_running = False elif event.message_finished: - msg = (b'' if isinstance(event, BytesMessage) else '') \ - .join(self._message_parts) + msg = (b"" if isinstance(event, BytesMessage) else "").join( + self._message_parts + ) self._message_size = 0 self._message_parts = [] try: @@ -1149,19 +1283,19 @@ async def _handle_message_event(self, event): pass async def _handle_ping_event(self, event): - ''' + """ Handle a PingReceived event. Wsproto queues a pong frame automatically, so this handler just needs to send it. :param wsproto.events.Ping event: - ''' - logger.debug('%s ping %r', self, event.payload) + """ + logger.debug("%s ping %r", self, event.payload) await self._send(event.response()) async def _handle_pong_event(self, event): - ''' + """ Handle a PongReceived event. When a pong is received, check if we have any ping requests waiting for @@ -1173,7 +1307,7 @@ async def _handle_pong_event(self, event): complicated if some handlers were sync. :param event: - ''' + """ payload = bytes(event.payload) try: event = self._pings[payload] @@ -1183,14 +1317,14 @@ async def _handle_pong_event(self, event): return while self._pings: key, event = self._pings.popitem(0) - skipped = ' [skipped] ' if payload != key else ' ' - logger.debug('%s pong%s%r', self, skipped, key) + skipped = " [skipped] " if payload != key else " " + logger.debug("%s pong%s%r", self, skipped, key) event.set() if payload == key: break async def _reader_task(self): - ''' A background task that reads network data and generates events. ''' + """A background task that reads network data and generates events.""" handlers = { AcceptConnection: self._handle_accept_connection_event, BytesMessage: self._handle_message_event, @@ -1216,12 +1350,12 @@ async def _reader_task(self): event_type = type(event) try: handler = handlers[event_type] - logger.debug('%s received event: %s', self, - event_type) + logger.debug("%s received event: %s", self, event_type) await handler(event) except KeyError: - logger.warning('%s received unknown event type: "%s"', self, - event_type) + logger.warning( + '%s received unknown event type: "%s"', self, event_type + ) except ConnectionClosed: self._reader_running = False break @@ -1233,27 +1367,26 @@ async def _reader_task(self): await self._abort_web_socket() break if len(data) == 0: - logger.debug('%s received zero bytes (connection closed)', - self) + logger.debug("%s received zero bytes (connection closed)", self) # If TCP closed before WebSocket, then record it as an abnormal # closure. if self._wsproto.state != ConnectionState.CLOSED: await self._abort_web_socket() break - logger.debug('%s received %d bytes', self, len(data)) + logger.debug("%s received %d bytes", self, len(data)) if self._wsproto.state != ConnectionState.CLOSED: try: self._wsproto.receive_data(data) except wsproto.utilities.RemoteProtocolError as err: - logger.debug('%s remote protocol error: %s', self, err) + logger.debug("%s remote protocol error: %s", self, err) if err.event_hint: await self._send(err.event_hint) await self._close_stream() - logger.debug('%s reader task finished', self) + logger.debug("%s reader task finished", self) async def _send(self, event): - ''' + """ Send an event to the remote WebSocket. The reader task and one or more writers might try to send messages at @@ -1261,10 +1394,10 @@ async def _send(self, event): requests to send data. :param wsproto.events.Event event: - ''' + """ data = self._wsproto.send(event) async with self._stream_lock: - logger.debug('%s sending %d bytes', self, len(data)) + logger.debug("%s sending %d bytes", self, len(data)) try: await self._stream.send_all(data) except (trio.BrokenResourceError, trio.ClosedResourceError): @@ -1273,7 +1406,8 @@ async def _send(self, event): class Endpoint: - ''' Represents a connection endpoint. ''' + """Represents a connection endpoint.""" + def __init__(self, address, port, is_ssl): #: IP address :class:`ipaddress.ip_address` self.address = ip_address(address) @@ -1284,37 +1418,43 @@ def __init__(self, address, port, is_ssl): @property def url(self): - ''' Return a URL representation of a TCP endpoint, e.g. - ``ws://127.0.0.1:80``. ''' - scheme = 'wss' if self.is_ssl else 'ws' - if (self.port == 80 and not self.is_ssl) or \ - (self.port == 443 and self.is_ssl): - port_str = '' + """Return a URL representation of a TCP endpoint, e.g. + ``ws://127.0.0.1:80``.""" + scheme = "wss" if self.is_ssl else "ws" + if (self.port == 80 and not self.is_ssl) or (self.port == 443 and self.is_ssl): + port_str = "" else: - port_str = ':' + str(self.port) + port_str = ":" + str(self.port) if self.address.version == 4: - return f'{scheme}://{self.address}{port_str}' - return f'{scheme}://[{self.address}]{port_str}' + return f"{scheme}://{self.address}{port_str}" + return f"{scheme}://[{self.address}]{port_str}" def __repr__(self): - ''' Return endpoint info as string. ''' + """Return endpoint info as string.""" return f'Endpoint(address="{self.address}", port={self.port}, is_ssl={self.is_ssl})' class WebSocketServer: - ''' + """ WebSocket server. The server class handles incoming connections on one or more ``Listener`` objects. For each incoming connection, it creates a ``WebSocketConnection`` instance and starts some background tasks, - ''' + """ - def __init__(self, handler, listeners, *, handler_nursery=None, + def __init__( + self, + handler, + listeners, + *, + handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT): - ''' + max_message_size=MAX_MESSAGE_SIZE, + connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, + ): + """ Constructor. Note that if ``host`` is ``None`` and ``port`` is zero, then you may get @@ -1333,9 +1473,9 @@ def __init__(self, handler, listeners, *, handler_nursery=None, to finish connection handshake before timing out. :param float disconnect_timeout: The number of seconds to wait for a client to finish the closing handshake before timing out. - ''' + """ if len(listeners) == 0: - raise ValueError('Listeners must contain at least one item.') + raise ValueError("Listeners must contain at least one item.") self._handler = handler self._handler_nursery = handler_nursery self._listeners = listeners @@ -1357,24 +1497,27 @@ def port(self): listener must be socket-based. """ if len(self._listeners) > 1: - raise RuntimeError('Cannot get port because this server has' - ' more than 1 listeners.') + raise RuntimeError( + "Cannot get port because this server has" " more than 1 listeners." + ) listener = self.listeners[0] try: return listener.port except AttributeError: - raise RuntimeError(f'This socket does not have a port: {repr(listener)}') from None + raise RuntimeError( + f"This socket does not have a port: {repr(listener)}" + ) from None @property def listeners(self): - ''' + """ Return a list of listener metadata. Each TCP listener is represented as an ``Endpoint`` instance. Other listener types are represented by their ``repr()``. :returns: Listeners :rtype list[Endpoint or str]: - ''' + """ listeners = [] for listener in self._listeners: socket, is_ssl = None, False @@ -1391,7 +1534,7 @@ def listeners(self): return listeners async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): - ''' + """ Start serving incoming connections requests. This method supports the Trio nursery start protocol: ``server = await @@ -1400,30 +1543,34 @@ async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): :param task_status: Part of the Trio nursery start protocol. :returns: This method never returns unless cancelled. - ''' + """ async with trio.open_nursery() as nursery: - serve_listeners = partial(trio.serve_listeners, - self._handle_connection, self._listeners, - handler_nursery=self._handler_nursery) + serve_listeners = partial( + trio.serve_listeners, + self._handle_connection, + self._listeners, + handler_nursery=self._handler_nursery, + ) await nursery.start(serve_listeners) - logger.debug('Listening on %s', - ','.join([str(l) for l in self.listeners])) + logger.debug("Listening on %s", ",".join([str(l) for l in self.listeners])) task_status.started(self) await trio.sleep_forever() async def _handle_connection(self, stream): - ''' + """ Handle an incoming connection by spawning a connection background task and a handler task inside a new nursery. :param stream: :type stream: trio.abc.Stream - ''' + """ async with trio.open_nursery() as nursery: - connection = WebSocketConnection(stream, + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.SERVER), message_queue_size=self._message_queue_size, - max_message_size=self._max_message_size) + max_message_size=self._max_message_size, + ) nursery.start_soon(connection._reader_task) with trio.move_on_after(self._connect_timeout) as connect_scope: request = await connection._get_request() diff --git a/trio_websocket/_version.py b/trio_websocket/_version.py index d1109e1..ce724d8 100644 --- a/trio_websocket/_version.py +++ b/trio_websocket/_version.py @@ -1 +1 @@ -__version__ = '0.11.0-dev' +__version__ = "0.11.0-dev" From 35c82c367770262048de8bbe487273ceb26c7393 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Wed, 10 May 2023 14:11:47 -0700 Subject: [PATCH 02/37] Fix formatting on a string constant --- trio_websocket/_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index e774229..89783f0 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1498,7 +1498,7 @@ def port(self): """ if len(self._listeners) > 1: raise RuntimeError( - "Cannot get port because this server has" " more than 1 listeners." + "Cannot get port because this server has more than 1 listeners." ) listener = self.listeners[0] try: From 65ef37639347b80cfa9bab798343a9f01dfcc769 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Wed, 10 May 2023 14:12:02 -0700 Subject: [PATCH 03/37] Work around spurious pylint error trio.MultiError is deprecated, and for technical reasons involving how the deprecation is implemented, this means pylint can't see it and thinks it doesn't exist. It does exist. --- trio_websocket/_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 89783f0..718af7a 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -67,7 +67,7 @@ def __exit__(self, ty, value, tb): return False if _TRIO_MULTI_ERROR: - filtered_exception = trio.MultiError.filter( + filtered_exception = trio.MultiError.filter( # pylint: disable=no-member _ignore_cancel, value ) # pylint: disable=no-member elif isinstance(value, BaseExceptionGroup): From 2970fb5d429b6e9d234bf0920ec8396707239fc7 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 26 Oct 2023 18:38:41 +0200 Subject: [PATCH 04/37] add black to 'make lint', and run it on all files --- Makefile | 1 + autobahn/client.py | 55 ++++++++++++++++++------------ autobahn/server.py | 41 ++++++++++++---------- examples/client.py | 71 ++++++++++++++++++++------------------- examples/generate-cert.py | 19 ++++++----- examples/server.py | 41 ++++++++++++---------- requirements-extras.in | 1 + tests/test_connection.py | 3 +- trio_websocket/_impl.py | 60 ++++++--------------------------- 9 files changed, 138 insertions(+), 154 deletions(-) diff --git a/Makefile b/Makefile index 2efced1..ff64f05 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ test: $(PYTHON) -m pytest --cov=trio_websocket --no-cov-on-fail lint: + $(PYTHON) -m black trio_websocket/ tests/ autobahn/ examples/ $(PYTHON) -m pylint trio_websocket/ tests/ autobahn/ examples/ publish: diff --git a/autobahn/client.py b/autobahn/client.py index d93be1c..1537009 100644 --- a/autobahn/client.py +++ b/autobahn/client.py @@ -1,7 +1,7 @@ -''' +""" This test client runs against the Autobahn test server. It is based on the test_client.py in wsproto. -''' +""" import argparse import json import logging @@ -11,28 +11,28 @@ from trio_websocket import open_websocket_url, ConnectionClosed -AGENT = 'trio-websocket' +AGENT = "trio-websocket" MAX_MESSAGE_SIZE = 16 * 1024 * 1024 logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('client') +logger = logging.getLogger("client") async def get_case_count(url): - url = url + '/getCaseCount' + url = url + "/getCaseCount" async with open_websocket_url(url) as conn: case_count = await conn.get_message() - logger.info('Case count=%s', case_count) + logger.info("Case count=%s", case_count) return int(case_count) async def get_case_info(url, case): - url = f'{url}/getCaseInfo?case={case}' + url = f"{url}/getCaseInfo?case={case}" async with open_websocket_url(url) as conn: return json.loads(await conn.get_message()) async def run_case(url, case): - url = f'{url}/runCase?case={case}&agent={AGENT}' + url = f"{url}/runCase?case={case}&agent={AGENT}" try: async with open_websocket_url(url, max_message_size=MAX_MESSAGE_SIZE) as conn: while True: @@ -43,7 +43,7 @@ async def run_case(url, case): async def update_reports(url): - url = f'{url}/updateReports?agent={AGENT}' + url = f"{url}/updateReports?agent={AGENT}" async with open_websocket_url(url) as conn: # This command runs as soon as we connect to it, so we don't need to # send any messages. @@ -51,7 +51,7 @@ async def update_reports(url): async def run_tests(args): - logger = logging.getLogger('trio-websocket') + logger = logging.getLogger("trio-websocket") if args.debug_cases: # Don't fetch case count when debugging a subset of test cases. It adds # noise to the debug logging. @@ -62,7 +62,7 @@ async def run_tests(args): test_cases = list(range(1, case_count + 1)) exception_cases = [] for case in test_cases: - case_id = (await get_case_info(args.url, case))['id'] + case_id = (await get_case_info(args.url, case))["id"] if case_count: logger.info("Running test case %s (%d of %d)", case_id, case, case_count) else: @@ -71,28 +71,39 @@ async def run_tests(args): try: await run_case(args.url, case) except Exception: # pylint: disable=broad-exception-caught - logger.exception(' runtime exception during test case %s (%d)', case_id, case) + logger.exception( + " runtime exception during test case %s (%d)", case_id, case + ) exception_cases.append(case_id) logger.setLevel(logging.INFO) - logger.info('Updating report') + logger.info("Updating report") await update_reports(args.url) if exception_cases: - logger.error('Runtime exception in %d of %d test cases: %s', - len(exception_cases), len(test_cases), exception_cases) + logger.error( + "Runtime exception in %d of %d test cases: %s", + len(exception_cases), + len(test_cases), + exception_cases, + ) sys.exit(1) def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Autobahn client for' - ' trio-websocket') - parser.add_argument('url', help='WebSocket URL for server') + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Autobahn client for" " trio-websocket" + ) + parser.add_argument("url", help="WebSocket URL for server") # TODO: accept case ID's rather than indices - parser.add_argument('debug_cases', type=int, nargs='*', help='Run' - ' individual test cases with debug logging (optional)') + parser.add_argument( + "debug_cases", + type=int, + nargs="*", + help="Run" " individual test cases with debug logging (optional)", + ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() trio.run(run_tests, args) diff --git a/autobahn/server.py b/autobahn/server.py index ff23846..9941445 100644 --- a/autobahn/server.py +++ b/autobahn/server.py @@ -1,4 +1,4 @@ -''' +""" This simple WebSocket server responds to text messages by reversing each message string and sending it back. @@ -7,34 +7,35 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" import argparse import logging import trio from trio_websocket import serve_websocket, ConnectionClosed, WebSocketRequest -BIND_IP = '0.0.0.0' +BIND_IP = "0.0.0.0" BIND_PORT = 9000 MAX_MESSAGE_SIZE = 16 * 1024 * 1024 logging.basicConfig() -logger = logging.getLogger('client') +logger = logging.getLogger("client") logger.setLevel(logging.INFO) connection_count = 0 async def main(): - ''' Main entry point. ''' - logger.info('Starting websocket server on ws://%s:%d', BIND_IP, BIND_PORT) - await serve_websocket(handler, BIND_IP, BIND_PORT, ssl_context=None, - max_message_size=MAX_MESSAGE_SIZE) + """Main entry point.""" + logger.info("Starting websocket server on ws://%s:%d", BIND_IP, BIND_PORT) + await serve_websocket( + handler, BIND_IP, BIND_PORT, ssl_context=None, max_message_size=MAX_MESSAGE_SIZE + ) async def handler(request: WebSocketRequest): - ''' Reverse incoming websocket messages and send them back. ''' + """Reverse incoming websocket messages and send them back.""" global connection_count # pylint: disable=global-statement connection_count += 1 - logger.info('Connection #%d', connection_count) + logger.info("Connection #%d", connection_count) ws = await request.accept() while True: try: @@ -43,20 +44,24 @@ async def handler(request: WebSocketRequest): except ConnectionClosed: break except Exception: # pylint: disable=broad-exception-caught - logger.exception(' runtime exception handling connection #%d', connection_count) + logger.exception( + " runtime exception handling connection #%d", connection_count + ) def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Autobahn server for' - ' trio-websocket') - parser.add_argument('-d', '--debug', action='store_true', - help='WebSocket URL for server') + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Autobahn server for" " trio-websocket" + ) + parser.add_argument( + "-d", "--debug", action="store_true", help="WebSocket URL for server" + ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() if args.debug: - logging.getLogger('trio-websocket').setLevel(logging.DEBUG) + logging.getLogger("trio-websocket").setLevel(logging.DEBUG) trio.run(main) diff --git a/examples/client.py b/examples/client.py index 030c12b..c17a830 100644 --- a/examples/client.py +++ b/examples/client.py @@ -1,10 +1,10 @@ -''' +""" This interactive WebSocket client allows the user to send frames to a WebSocket server, including text message, ping, and close frames. To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" import argparse import logging import pathlib @@ -21,49 +21,51 @@ def commands(): - ''' Print the supported commands. ''' - print('Commands: ') - print('send -> send message') - print('ping -> send ping with payload') - print('close [] -> politely close connection with optional reason') + """Print the supported commands.""" + print("Commands: ") + print("send -> send message") + print("ping -> send ping with payload") + print("close [] -> politely close connection with optional reason") print() def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Example trio-websocket client') - parser.add_argument('--heartbeat', action='store_true', - help='Create a heartbeat task') - parser.add_argument('url', help='WebSocket URL to connect to') + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Example trio-websocket client") + parser.add_argument( + "--heartbeat", action="store_true", help="Create a heartbeat task" + ) + parser.add_argument("url", help="WebSocket URL to connect to") return parser.parse_args() async def main(args): - ''' Main entry point, returning False in the case of logged error. ''' - if urllib.parse.urlsplit(args.url).scheme == 'wss': + """Main entry point, returning False in the case of logged error.""" + if urllib.parse.urlsplit(args.url).scheme == "wss": # Configure SSL context to handle our self-signed certificate. Most # clients won't need to do this. try: ssl_context = ssl.create_default_context() - ssl_context.load_verify_locations(here / 'fake.ca.pem') + ssl_context.load_verify_locations(here / "fake.ca.pem") except FileNotFoundError: - logging.error('Did not find file "fake.ca.pem". You need to run' - ' generate-cert.py') + logging.error( + 'Did not find file "fake.ca.pem". You need to run generate-cert.py' + ) return False else: ssl_context = None try: - logging.debug('Connecting to WebSocket…') + logging.debug("Connecting to WebSocket…") async with open_websocket_url(args.url, ssl_context) as conn: await handle_connection(conn, args.heartbeat) except HandshakeError as e: - logging.error('Connection attempt failed: %s', e) + logging.error("Connection attempt failed: %s", e) return False async def handle_connection(ws, use_heartbeat): - ''' Handle the connection. ''' - logging.debug('Connected!') + """Handle the connection.""" + logging.debug("Connected!") try: async with trio.open_nursery() as nursery: if use_heartbeat: @@ -71,12 +73,12 @@ async def handle_connection(ws, use_heartbeat): nursery.start_soon(get_commands, ws) nursery.start_soon(get_messages, ws) except ConnectionClosed as cc: - reason = '' if cc.reason.reason is None else f'"{cc.reason.reason}"' - print(f'Closed: {cc.reason.code}/{cc.reason.name} {reason}') + reason = "" if cc.reason.reason is None else f'"{cc.reason.reason}"' + print(f"Closed: {cc.reason.code}/{cc.reason.name} {reason}") async def heartbeat(ws, timeout, interval): - ''' + """ Send periodic pings on WebSocket ``ws``. Wait up to ``timeout`` seconds to send a ping and receive a pong. Raises @@ -92,7 +94,7 @@ async def heartbeat(ws, timeout, interval): :raises: ``ConnectionClosed`` if ``ws`` is closed. :raises: ``TooSlowError`` if the timeout expires. :returns: This function runs until cancelled. - ''' + """ while True: with trio.fail_after(timeout): await ws.ping() @@ -100,20 +102,19 @@ async def heartbeat(ws, timeout, interval): async def get_commands(ws): - ''' In a loop: get a command from the user and execute it. ''' + """In a loop: get a command from the user and execute it.""" while True: - cmd = await trio.to_thread.run_sync(input, 'cmd> ', - cancellable=True) - if cmd.startswith('ping'): - payload = cmd[5:].encode('utf8') or None + cmd = await trio.to_thread.run_sync(input, "cmd> ", cancellable=True) + if cmd.startswith("ping"): + payload = cmd[5:].encode("utf8") or None await ws.ping(payload) - elif cmd.startswith('send'): + elif cmd.startswith("send"): message = cmd[5:] or None if message is None: logging.error('The "send" command requires a message.') else: await ws.send_message(message) - elif cmd.startswith('close'): + elif cmd.startswith("close"): reason = cmd[6:] or None await ws.aclose(code=1000, reason=reason) break @@ -124,13 +125,13 @@ async def get_commands(ws): async def get_messages(ws): - ''' In a loop: get a WebSocket message and print it out. ''' + """In a loop: get a WebSocket message and print it out.""" while True: message = await ws.get_message() - print(f'message: {message}') + print(f"message: {message}") -if __name__ == '__main__': +if __name__ == "__main__": try: if not trio.run(main, parse_args()): sys.exit(1) diff --git a/examples/generate-cert.py b/examples/generate-cert.py index cc21698..4f0e6ff 100644 --- a/examples/generate-cert.py +++ b/examples/generate-cert.py @@ -3,22 +3,23 @@ import trustme + def main(): here = pathlib.Path(__file__).parent - ca_path = here / 'fake.ca.pem' - server_path = here / 'fake.server.pem' + ca_path = here / "fake.ca.pem" + server_path = here / "fake.server.pem" if ca_path.exists() and server_path.exists(): - print('The CA ceritificate and server certificate already exist.') + print("The CA ceritificate and server certificate already exist.") sys.exit(1) - print('Creating self-signed certificate for localhost/127.0.0.1:') + print("Creating self-signed certificate for localhost/127.0.0.1:") ca_cert = trustme.CA() ca_cert.cert_pem.write_to_path(ca_path) - print(f' * CA certificate: {ca_path}') - server_cert = ca_cert.issue_server_cert('localhost', '127.0.0.1') + print(f" * CA certificate: {ca_path}") + server_cert = ca_cert.issue_server_cert("localhost", "127.0.0.1") server_cert.private_key_and_cert_chain_pem.write_to_path(server_path) - print(f' * Server certificate: {server_path}') - print('Done') + print(f" * Server certificate: {server_path}") + print("Done") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/server.py b/examples/server.py index 611d89b..e77afb0 100644 --- a/examples/server.py +++ b/examples/server.py @@ -1,4 +1,4 @@ -''' +""" This simple WebSocket server responds to text messages by reversing each message string and sending it back. @@ -7,7 +7,7 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" import argparse import logging import pathlib @@ -23,33 +23,38 @@ def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Example trio-websocket client') - parser.add_argument('--ssl', action='store_true', help='Use SSL') - parser.add_argument('host', help='Host interface to bind. If omitted, ' - 'then bind all interfaces.', nargs='?') - parser.add_argument('port', type=int, help='Port to bind.') + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Example trio-websocket client") + parser.add_argument("--ssl", action="store_true", help="Use SSL") + parser.add_argument( + "host", + help="Host interface to bind. If omitted, " "then bind all interfaces.", + nargs="?", + ) + parser.add_argument("port", type=int, help="Port to bind.") return parser.parse_args() async def main(args): - ''' Main entry point. ''' - logging.info('Starting websocket server…') + """Main entry point.""" + logging.info("Starting websocket server…") if args.ssl: ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) try: - ssl_context.load_cert_chain(here / 'fake.server.pem') + ssl_context.load_cert_chain(here / "fake.server.pem") except FileNotFoundError: - logging.error('Did not find file "fake.server.pem". You need to run' - ' generate-cert.py') + logging.error( + 'Did not find file "fake.server.pem". You need to run' + " generate-cert.py" + ) else: ssl_context = None - host = None if args.host == '*' else args.host + host = None if args.host == "*" else args.host await serve_websocket(handler, host, args.port, ssl_context) async def handler(request): - ''' Reverse incoming websocket messages and send them back. ''' + """Reverse incoming websocket messages and send them back.""" logging.info('Handler starting on path "%s"', request.path) ws = await request.accept() while True: @@ -57,12 +62,12 @@ async def handler(request): message = await ws.get_message() await ws.send_message(message[::-1]) except ConnectionClosed: - logging.info('Connection closed') + logging.info("Connection closed") break - logging.info('Handler exiting') + logging.info("Handler exiting") -if __name__ == '__main__': +if __name__ == "__main__": try: trio.run(main, parse_args()) except KeyboardInterrupt: diff --git a/requirements-extras.in b/requirements-extras.in index 9f4d0c5..1abe99b 100644 --- a/requirements-extras.in +++ b/requirements-extras.in @@ -1,4 +1,5 @@ # requirements for `make lint/docs/publish` +black pylint sphinx sphinxcontrib-trio diff --git a/tests/test_connection.py b/tests/test_connection.py index d608375..f101878 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1112,7 +1112,7 @@ async def test_remote_close_rude(): async def client(): client_conn = await wrap_client_stream(nursery, client_stream, HOST, RESOURCE) assert not client_conn.closed - await client_conn.send_message('Hello from client!') + await client_conn.send_message("Hello from client!") with pytest.raises(ConnectionClosed): await client_conn.get_message() @@ -1131,7 +1131,6 @@ async def server(): # pump the messages over memory_stream_pump(server_stream.send_stream, client_stream.receive_stream) - async with trio.open_nursery() as nursery: nursery.start_soon(server) nursery.start_soon(client) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 7cb9865..b577251 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -13,7 +13,6 @@ import trio import trio.abc -from exceptiongroup import BaseExceptionGroup from wsproto import ConnectionType, WSConnection from wsproto.connection import ConnectionState import wsproto.frame_protocol as wsframeproto @@ -30,6 +29,9 @@ ) import wsproto.utilities +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin + _TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split(".")[:2])) < (0, 22) CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds @@ -1344,46 +1346,6 @@ async def _reader_task(self): except ConnectionClosed: self._reader_running = False -<<<<<<< HEAD - while self._reader_running: - # Process events. - for event in self._wsproto.events(): - event_type = type(event) - try: - handler = handlers[event_type] - logger.debug("%s received event: %s", self, event_type) - await handler(event) - except KeyError: - logger.warning( - '%s received unknown event type: "%s"', self, event_type - ) - except ConnectionClosed: - self._reader_running = False - break - - # Get network data. - try: - data = await self._stream.receive_some(RECEIVE_BYTES) - except (trio.BrokenResourceError, trio.ClosedResourceError): - await self._abort_web_socket() - break - if len(data) == 0: - logger.debug("%s received zero bytes (connection closed)", self) - # If TCP closed before WebSocket, then record it as an abnormal - # closure. - if self._wsproto.state != ConnectionState.CLOSED: - await self._abort_web_socket() - break - logger.debug("%s received %d bytes", self, len(data)) - if self._wsproto.state != ConnectionState.CLOSED: - try: - self._wsproto.receive_data(data) - except wsproto.utilities.RemoteProtocolError as err: - logger.debug("%s remote protocol error: %s", self, err) - if err.event_hint: - await self._send(err.event_hint) - await self._close_stream() -======= async with self._send_channel: while self._reader_running: # Process events. @@ -1391,12 +1353,12 @@ async def _reader_task(self): event_type = type(event) try: handler = handlers[event_type] - logger.debug('%s received event: %s', self, - event_type) + logger.debug("%s received event: %s", self, event_type) await handler(event) except KeyError: - logger.warning('%s received unknown event type: "%s"', self, - event_type) + logger.warning( + '%s received unknown event type: "%s"', self, event_type + ) except ConnectionClosed: self._reader_running = False break @@ -1408,23 +1370,21 @@ async def _reader_task(self): await self._abort_web_socket() break if len(data) == 0: - logger.debug('%s received zero bytes (connection closed)', - self) + logger.debug("%s received zero bytes (connection closed)", self) # If TCP closed before WebSocket, then record it as an abnormal # closure. if self._wsproto.state != ConnectionState.CLOSED: await self._abort_web_socket() break - logger.debug('%s received %d bytes', self, len(data)) + logger.debug("%s received %d bytes", self, len(data)) if self._wsproto.state != ConnectionState.CLOSED: try: self._wsproto.receive_data(data) except wsproto.utilities.RemoteProtocolError as err: - logger.debug('%s remote protocol error: %s', self, err) + logger.debug("%s remote protocol error: %s", self, err) if err.event_hint: await self._send(err.event_hint) await self._close_stream() ->>>>>>> origin/HEAD logger.debug("%s reader task finished", self) From 903dfc35a23c64e2f5d03565f48a076e9b6f667c Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 13 Jun 2024 10:42:01 -0500 Subject: [PATCH 05/37] Run new version of black on all files --- autobahn/client.py | 1 + autobahn/server.py | 1 + examples/client.py | 1 + examples/server.py | 1 + setup.py | 52 +++++++++++++++++++------------------- tests/test_connection.py | 3 ++- trio_websocket/_version.py | 2 +- 7 files changed, 33 insertions(+), 28 deletions(-) diff --git a/autobahn/client.py b/autobahn/client.py index 1537009..dc0e890 100644 --- a/autobahn/client.py +++ b/autobahn/client.py @@ -2,6 +2,7 @@ This test client runs against the Autobahn test server. It is based on the test_client.py in wsproto. """ + import argparse import json import logging diff --git a/autobahn/server.py b/autobahn/server.py index 9941445..5263306 100644 --- a/autobahn/server.py +++ b/autobahn/server.py @@ -8,6 +8,7 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. """ + import argparse import logging diff --git a/examples/client.py b/examples/client.py index c17a830..08610cd 100644 --- a/examples/client.py +++ b/examples/client.py @@ -5,6 +5,7 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. """ + import argparse import logging import pathlib diff --git a/examples/server.py b/examples/server.py index e77afb0..0bcca25 100644 --- a/examples/server.py +++ b/examples/server.py @@ -8,6 +8,7 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. """ + import argparse import logging import pathlib diff --git a/setup.py b/setup.py index a5040a6..b743461 100644 --- a/setup.py +++ b/setup.py @@ -10,43 +10,43 @@ # Get description -with (here / 'README.md').open(encoding='utf-8') as f: +with (here / "README.md").open(encoding="utf-8") as f: long_description = f.read() setup( - name='trio-websocket', - version=version['__version__'], - description='WebSocket library for Trio', + name="trio-websocket", + version=version["__version__"], + description="WebSocket library for Trio", long_description=long_description, - long_description_content_type='text/markdown', - url='https://github.com/python-trio/trio-websocket', - author='Mark E. Haase', - author_email='mehaase@gmail.com', + long_description_content_type="text/markdown", + url="https://github.com/python-trio/trio-websocket", + author="Mark E. Haase", + author_email="mehaase@gmail.com", classifiers=[ # See https://pypi.org/classifiers/ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Libraries', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", ], python_requires=">=3.7", - keywords='websocket client server trio', - packages=find_packages(exclude=['docs', 'examples', 'tests']), + keywords="websocket client server trio", + packages=find_packages(exclude=["docs", "examples", "tests"]), install_requires=[ 'exceptiongroup; python_version<"3.11"', - 'trio>=0.11', - 'wsproto>=0.14', + "trio>=0.11", + "wsproto>=0.14", ], project_urls={ - 'Bug Reports': 'https://github.com/python-trio/trio-websocket/issues', - 'Source': 'https://github.com/python-trio/trio-websocket', + "Bug Reports": "https://github.com/python-trio/trio-websocket/issues", + "Source": "https://github.com/python-trio/trio-websocket", }, ) diff --git a/tests/test_connection.py b/tests/test_connection.py index f0e5434..a426109 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -29,6 +29,7 @@ the server to block until the client has sent the closing handshake. In other circumstances """ + from functools import partial, wraps import ssl from unittest.mock import patch @@ -422,7 +423,7 @@ async def handler(request): @fail_after(1) async def test_handshake_server_headers(nursery): async def handler(request): - headers = [('X-Test-Header', 'My test header')] + headers = [("X-Test-Header", "My test header")] server_ws = await request.accept(extra_headers=headers) server = await nursery.start(serve_websocket, handler, HOST, 0, None) diff --git a/trio_websocket/_version.py b/trio_websocket/_version.py index 2320701..5c47800 100644 --- a/trio_websocket/_version.py +++ b/trio_websocket/_version.py @@ -1 +1 @@ -__version__ = '0.12.0-dev' +__version__ = "0.12.0-dev" From d64908b15d35cd061084cb6ec5532cfb1eefc736 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 13 Jun 2024 10:42:39 -0500 Subject: [PATCH 06/37] Run black on docs config with manual spacing fixes --- docs/conf.py | 67 ++++++++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 649051b..88a2596 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,11 +19,12 @@ # -- Project information ----------------------------------------------------- -project = 'Trio WebSocket' -copyright = '2018, Hyperion Gray' -author = 'Hyperion Gray' +project = "Trio WebSocket" +copyright = "2018, Hyperion Gray" +author = "Hyperion Gray" from trio_websocket._version import __version__ as version + release = version @@ -37,22 +38,22 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinxcontrib_trio', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinxcontrib_trio", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -64,7 +65,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None @@ -75,7 +76,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -86,7 +87,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -102,26 +103,22 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'TrioWebSocketdoc' +htmlhelp_basename = "TrioWebSocketdoc" # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). - # # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). # + # The font size ('10pt', '11pt' or '12pt'). # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. # + # Additional stuff for the LaTeX preamble. # 'preamble': '', - - # Latex figure (float) alignment # + # Latex figure (float) alignment # 'figure_align': 'htbp', } @@ -129,8 +126,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'TrioWebSocket.tex', 'Trio WebSocket Documentation', - 'Hyperion Gray', 'manual'), + ( + master_doc, + "TrioWebSocket.tex", + "Trio WebSocket Documentation", + "Hyperion Gray", + "manual", + ), ] @@ -138,10 +140,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'triowebsocket', 'Trio WebSocket Documentation', - [author], 1) -] +man_pages = [(master_doc, "triowebsocket", "Trio WebSocket Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -150,9 +149,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'TrioWebSocket', 'Trio WebSocket Documentation', - author, 'TrioWebSocket', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "TrioWebSocket", + "Trio WebSocket Documentation", + author, + "TrioWebSocket", + "One line description of project.", + "Miscellaneous", + ), ] @@ -171,10 +176,10 @@ # epub_uid = '' # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- intersphinx_mapping = { - 'trio': ('https://trio.readthedocs.io/en/stable/', None), + "trio": ("https://trio.readthedocs.io/en/stable/", None), } From e7dd16b6ba28628ddb5f9fa941bdb0ff3ace3d3f Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 13 Jun 2024 10:46:34 -0500 Subject: [PATCH 07/37] Fix broken merge commit --- trio_websocket/_impl.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 753746f..7cbbd95 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -69,15 +69,10 @@ def __exit__(self, ty, value, tb): if value is None or not self._armed: return False -<<<<<<< HEAD - if _TRIO_MULTI_ERROR: - filtered_exception = trio.MultiError.filter( # pylint: disable=no-member + if _TRIO_MULTI_ERROR: # pragma: no cover + filtered_exception = trio.MultiError.filter( _ignore_cancel, value ) # pylint: disable=no-member -======= - if _TRIO_MULTI_ERROR: # pragma: no cover - filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member ->>>>>>> origin/master elif isinstance(value, BaseExceptionGroup): filtered_exception = value.subgroup( lambda exc: not isinstance(exc, trio.Cancelled) From 8fca0593a1d1bebc304e0b78a2c55ad76511971e Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 13 Jun 2024 10:49:57 -0500 Subject: [PATCH 08/37] Add midding black dev dependency --- requirements-dev-full.txt | 2 ++ requirements-dev.in | 1 + 2 files changed, 3 insertions(+) diff --git a/requirements-dev-full.txt b/requirements-dev-full.txt index 6ad3f76..c5b35d0 100644 --- a/requirements-dev-full.txt +++ b/requirements-dev-full.txt @@ -17,6 +17,8 @@ attrs==22.2.0 # trio babel==2.12.1 # via sphinx +black==24.4.2 + # via -r requirements-dev.in bleach==6.0.0 # via readme-renderer build==0.10.0 diff --git a/requirements-dev.in b/requirements-dev.in index 922fb76..30907fd 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -1,5 +1,6 @@ # requirements for `make test` and dependency management attrs>=19.2.0 +black>=24.4.2 pip-tools>=5.5.0 pytest>=4.6 pytest-cov From 32e09cdba5d5f5a2d84c6902b3568c5ba1ea7431 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Fri, 2 Aug 2024 01:36:40 -0500 Subject: [PATCH 09/37] Ignore more pylint issues and re-run black --- trio_websocket/_impl.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index f2dcfc3..f95ed6a 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -72,10 +72,13 @@ def __exit__(self, ty, value, tb): return False if _TRIO_MULTI_ERROR: # pragma: no cover - filtered_exception = trio.MultiError.filter( + filtered_exception = trio.MultiError.filter( # pylint: disable=no-member _ignore_cancel, value ) # pylint: disable=no-member - elif isinstance(value, BaseExceptionGroup): # pylint: disable=possibly-used-before-assignment + elif isinstance( + value, + BaseExceptionGroup # pylint: disable=possibly-used-before-assignment + ): filtered_exception = value.subgroup( lambda exc: not isinstance(exc, trio.Cancelled) ) @@ -92,7 +95,7 @@ async def open_websocket( *, use_ssl: Union[bool, ssl.SSLContext], subprotocols: Optional[Iterable[str]] = None, - extra_headers: Optional[list[tuple[bytes,bytes]]] = None, + extra_headers: Optional[list[tuple[bytes, bytes]]] = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, connect_timeout: float = CONN_TIMEOUT, @@ -861,9 +864,9 @@ def __init__( self._initial_request = None self._path = path self._subprotocol: Optional[str] = None - self._handshake_headers: tuple[tuple[str,str], ...] = tuple() + self._handshake_headers: tuple[tuple[str, str], ...] = tuple() self._reject_status = 0 - self._reject_headers: tuple[tuple[str,str], ...] = tuple() + self._reject_headers: tuple[tuple[str, str], ...] = tuple() self._reject_body = b"" self._send_channel, self._recv_channel = trio.open_memory_channel[ Union[bytes, str] From 6f6da213f62b9341bb308cc858d049ab0b45702d Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Fri, 2 Aug 2024 01:38:39 -0500 Subject: [PATCH 10/37] Re-run `black tests` --- tests/test_connection.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 526ddf7..53f54cf 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -29,6 +29,7 @@ the server to block until the client has sent the closing handshake. In other circumstances """ + from __future__ import annotations from functools import partial, wraps @@ -304,13 +305,20 @@ async def test_client_open_invalid_url(echo_server): async with open_websocket_url("http://foo.com/bar") as conn: pass + async def test_client_open_invalid_ssl(echo_server, nursery): - with pytest.raises(TypeError, match='`use_ssl` argument must be bool or ssl.SSLContext'): + with pytest.raises( + TypeError, match="`use_ssl` argument must be bool or ssl.SSLContext" + ): await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, use_ssl=1) - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' - with pytest.raises(ValueError, match='^SSL context must be None for ws: URL scheme$' ): - await connect_websocket_url(nursery, url, ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}" + with pytest.raises( + ValueError, match="^SSL context must be None for ws: URL scheme$" + ): + await connect_websocket_url( + nursery, url, ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ) async def test_ascii_encoded_path_is_ok(echo_server): From e43a087dde696de4734bb261ba850dc1dd61409f Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Fri, 2 Aug 2024 01:39:36 -0500 Subject: [PATCH 11/37] Re-run full black again `black trio_websocket/ tests/ autobahn/ examples/` --- trio_websocket/_impl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index f95ed6a..6c012f5 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -76,8 +76,7 @@ def __exit__(self, ty, value, tb): _ignore_cancel, value ) # pylint: disable=no-member elif isinstance( - value, - BaseExceptionGroup # pylint: disable=possibly-used-before-assignment + value, BaseExceptionGroup # pylint: disable=possibly-used-before-assignment ): filtered_exception = value.subgroup( lambda exc: not isinstance(exc, trio.Cancelled) From 77e5779d8813e9eb7c76a6cc43cc810cb008a508 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 20 Oct 2024 15:57:00 +0200 Subject: [PATCH 12/37] fix loss of context/cause on exceptions raised inside open_websocket --- tests/test_connection.py | 37 ++++++++++++++++++++++++++++++++++++- trio_websocket/_impl.py | 23 +++++++++++++++++------ 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6cccefa..e1292a8 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -513,6 +513,8 @@ async def handler(request): server_ws = await request.accept() await server_ws.ping(b"a") user_cancelled = None + user_cancelled_cause = None + user_cancelled_context = None server = await nursery.start(serve_websocket, handler, HOST, 0, None) with trio.move_on_after(2): @@ -522,8 +524,18 @@ async def handler(request): await trio.sleep_forever() except trio.Cancelled as e: user_cancelled = e + user_cancelled_cause = e.__cause__ + user_cancelled_context = e.__context__ raise - assert exc_info.value is user_cancelled + + # a copy of user_cancelled is reraised + assert exc_info.value is not user_cancelled + # with the same cause + assert exc_info.value.__cause__ is user_cancelled_cause + # the context is the exception group, which contains the original user_cancelled + assert exc_info.value.__context__.exceptions[1] is user_cancelled + assert exc_info.value.__context__.exceptions[1].__cause__ is user_cancelled_cause + assert exc_info.value.__context__.exceptions[1].__context__ is user_cancelled_context def _trio_default_non_strict_exception_groups() -> bool: assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" @@ -560,6 +572,29 @@ async def handler(request): RaisesGroup(ValueError)))).matches(exc.value) +async def test_user_exception_cause(nursery) -> None: + async def handler(request): + await request.accept() + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + e_context = TypeError("foo") + e_primary = ValueError("bar") + e_cause = RuntimeError("zee") + with pytest.raises(ValueError) as exc_info: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): + try: + raise e_context + except TypeError: + raise e_primary from e_cause + e = exc_info.value + # a copy is reraised + assert e is not e_primary + assert e.__cause__ is e_cause + + # the nursery-internal group is injected as context + assert isinstance(e.__context__, ExceptionGroup) + assert e.__context__.exceptions[0] is e_primary + assert e.__context__.exceptions[0].__context__ is e_context + @fail_after(1) async def test_reject_handshake(nursery): async def handler(request): diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index a71e0be..e69959b 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import sys from collections import OrderedDict from contextlib import asynccontextmanager @@ -91,6 +92,16 @@ def __exit__(self, ty, value, tb): filtered_exception = _ignore_cancel(value) return filtered_exception is None +def copy_exc(e: BaseException) -> BaseException: + """Copy an exception. + + `copy.copy` fails on `trio.Cancelled`, and on exceptions with a custom `__init__` + that calls `super().__init__()`. It may be the case that this also fails on something. + """ + cls = type(e) + result = cls.__new__(cls) + result.__dict__ = copy.copy(e.__dict__) + return result @asynccontextmanager async def open_websocket( @@ -205,7 +216,7 @@ async def _close_connection(connection: WebSocketConnection) -> None: except _TRIO_EXC_GROUP_TYPE as e: # user_error, or exception bubbling up from _reader_task if len(e.exceptions) == 1: - raise e.exceptions[0] + raise copy_exc(e.exceptions[0]) from e.exceptions[0].__cause__ # contains at most 1 non-cancelled exceptions exception_to_raise: BaseException|None = None @@ -222,21 +233,21 @@ async def _close_connection(connection: WebSocketConnection) -> None: if user_error is not None: # no reason to raise from e, just to include a bunch of extra # cancelleds. - raise user_error # pylint: disable=raise-missing-from + raise copy_exc(user_error) from user_error.__cause__ # multiple internal Cancelled is not possible afaik - raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from - raise exception_to_raise + raise copy_exc(e.exceptions[0]) from e # pragma: no cover + raise copy_exc(exception_to_raise) from exception_to_raise.__cause__ # if we have any KeyboardInterrupt in the group, make sure to raise it. for sub_exc in e.exceptions: if isinstance(sub_exc, KeyboardInterrupt): - raise sub_exc from e + raise copy_exc(sub_exc) from e # Both user code and internal code raised non-cancelled exceptions. # We "hide" the internal exception(s) in the __cause__ and surface # the user_error. if user_error is not None: - raise user_error from e + raise copy_exc(user_error) from e raise TrioWebsocketInternalError( "The trio-websocket API is not expected to raise multiple exceptions. " From df0f56d06c8462b75aacae4443ffafb8ac001346 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 20 Oct 2024 17:07:14 +0200 Subject: [PATCH 13/37] fix test on non-strict --- tests/test_connection.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index e1292a8..4058409 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -586,14 +586,19 @@ async def handler(request): except TypeError: raise e_primary from e_cause e = exc_info.value - # a copy is reraised - assert e is not e_primary - assert e.__cause__ is e_cause - - # the nursery-internal group is injected as context - assert isinstance(e.__context__, ExceptionGroup) - assert e.__context__.exceptions[0] is e_primary - assert e.__context__.exceptions[0].__context__ is e_context + if _trio_default_non_strict_exception_groups(): + assert e is e_primary + assert e.__cause__ is e_cause + assert e.__context__ is e_context + else: + # a copy is reraised to avoid losing e_context + assert e is not e_primary + assert e.__cause__ is e_cause + + # the nursery-internal group is injected as context + assert isinstance(e.__context__, ExceptionGroup) + assert e.__context__.exceptions[0] is e_primary + assert e.__context__.exceptions[0].__context__ is e_context @fail_after(1) async def test_reject_handshake(nursery): From a4e0562ce46f39ab761aaa00718754a1c1953a7c Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 20 Oct 2024 17:10:34 +0200 Subject: [PATCH 14/37] fix another test fail --- tests/test_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 4058409..61c2848 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -596,7 +596,7 @@ async def handler(request): assert e.__cause__ is e_cause # the nursery-internal group is injected as context - assert isinstance(e.__context__, ExceptionGroup) + assert isinstance(e.__context__, _TRIO_EXC_GROUP_TYPE) assert e.__context__.exceptions[0] is e_primary assert e.__context__.exceptions[0].__context__ is e_context From 5defd0341677029efbfdaf640cd254425f6c1d1f Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 23 Oct 2024 12:19:30 +0200 Subject: [PATCH 15/37] no-copy solution that completely hides the exceptiongroup in most cases --- tests/test_connection.py | 51 ++++++++++++++-------------------- trio_websocket/_impl.py | 60 +++++++++++++++++++++++++++++++--------- 2 files changed, 67 insertions(+), 44 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 61c2848..304390d 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -452,7 +452,6 @@ async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock): Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup """ async def ki_raising_ping_handler(*args, **kwargs) -> None: - print("raising ki") raise KeyboardInterrupt monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler) async def handler(request): @@ -474,11 +473,14 @@ async def handler(request): async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock): """_reader_task._handle_ping_event triggers ValueError. user code also raises exception. - internal exception is in __cause__ exceptiongroup and user exc is delivered + internal exception is in __context__ exceptiongroup and user exc is delivered """ - my_value_error = ValueError() + internal_error = ValueError() + internal_error.__context__ = TypeError() + user_error = NameError() + user_error_context = KeyError() async def raising_ping_event(*args, **kwargs) -> None: - raise my_value_error + raise internal_error monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event) async def handler(request): @@ -486,15 +488,17 @@ async def handler(request): await server_ws.ping(b"a") server = await nursery.start(serve_websocket, handler, HOST, 0, None) - with pytest.raises(trio.TooSlowError) as exc_info: + with pytest.raises(type(user_error)) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): - with trio.fail_after(1) as cs: - cs.shield = True - await trio.sleep(2) + await trio.lowlevel.checkpoint() + user_error.__context__ = user_error_context + raise user_error - e_cause = exc_info.value.__cause__ - assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE) - assert my_value_error in e_cause.exceptions + assert exc_info.value is user_error + e_context = exc_info.value.__context__ + assert isinstance(e_context, BaseExceptionGroup) + assert internal_error in e_context.exceptions + assert user_error_context in e_context.exceptions @fail_after(5) async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock): @@ -528,14 +532,9 @@ async def handler(request): user_cancelled_context = e.__context__ raise - # a copy of user_cancelled is reraised - assert exc_info.value is not user_cancelled - # with the same cause + assert exc_info.value is user_cancelled assert exc_info.value.__cause__ is user_cancelled_cause - # the context is the exception group, which contains the original user_cancelled - assert exc_info.value.__context__.exceptions[1] is user_cancelled - assert exc_info.value.__context__.exceptions[1].__cause__ is user_cancelled_cause - assert exc_info.value.__context__.exceptions[1].__context__ is user_cancelled_context + assert exc_info.value.__context__ is user_cancelled_context def _trio_default_non_strict_exception_groups() -> bool: assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" @@ -586,19 +585,9 @@ async def handler(request): except TypeError: raise e_primary from e_cause e = exc_info.value - if _trio_default_non_strict_exception_groups(): - assert e is e_primary - assert e.__cause__ is e_cause - assert e.__context__ is e_context - else: - # a copy is reraised to avoid losing e_context - assert e is not e_primary - assert e.__cause__ is e_cause - - # the nursery-internal group is injected as context - assert isinstance(e.__context__, _TRIO_EXC_GROUP_TYPE) - assert e.__context__.exceptions[0] is e_primary - assert e.__context__.exceptions[0].__context__ is e_context + assert e is e_primary + assert e.__cause__ is e_cause + assert e.__context__ is e_context @fail_after(1) async def test_reject_handshake(nursery): diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index e69959b..4ae7338 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -12,7 +12,7 @@ import ssl import struct import urllib.parse -from typing import Iterable, List, Optional, Union +from typing import Iterable, List, NoReturn, Optional, Union import outcome import trio @@ -192,10 +192,29 @@ async def _close_connection(connection: WebSocketConnection) -> None: except trio.TooSlowError: raise DisconnectionTimeout from None + def _raise(exc: BaseException) -> NoReturn: + __tracebackhide__ = True + context = exc.__context__ + try: + raise exc + finally: + exc.__context__ = context + del exc, context + connection: WebSocketConnection|None=None close_result: outcome.Maybe[None] | None = None user_error = None + # Unwrapping exception groups has a lot of pitfalls, one of them stemming from + # the exception we raise also being inside the group that's set as the context. + # This leads to loss of info unless properly handled. + # See https://github.com/python-trio/flake8-async/issues/298 + # We therefore save the exception before raising it, and save our intended context, + # so they can be modified in the `finally`. + exc_to_raise = None + exc_context = None + # by avoiding use of `raise .. from ..` we leave the original __cause__ + try: async with trio.open_nursery() as new_nursery: result = await outcome.acapture(_open_connection, new_nursery) @@ -216,7 +235,7 @@ async def _close_connection(connection: WebSocketConnection) -> None: except _TRIO_EXC_GROUP_TYPE as e: # user_error, or exception bubbling up from _reader_task if len(e.exceptions) == 1: - raise copy_exc(e.exceptions[0]) from e.exceptions[0].__cause__ + _raise(e.exceptions[0]) # contains at most 1 non-cancelled exceptions exception_to_raise: BaseException|None = None @@ -229,25 +248,40 @@ async def _close_connection(connection: WebSocketConnection) -> None: else: if exception_to_raise is None: # all exceptions are cancelled - # prefer raising the one from the user, for traceback reasons + # we reraise the user exception and throw out internal if user_error is not None: - # no reason to raise from e, just to include a bunch of extra - # cancelleds. - raise copy_exc(user_error) from user_error.__cause__ + _raise(user_error) # multiple internal Cancelled is not possible afaik - raise copy_exc(e.exceptions[0]) from e # pragma: no cover - raise copy_exc(exception_to_raise) from exception_to_raise.__cause__ + # but if so we just raise one of them + _raise(e.exceptions[0]) + # raise the non-cancelled exception + _raise(exception_to_raise) - # if we have any KeyboardInterrupt in the group, make sure to raise it. + # if we have any KeyboardInterrupt in the group, raise a new KeyboardInterrupt + # with the group as cause & context for sub_exc in e.exceptions: if isinstance(sub_exc, KeyboardInterrupt): - raise copy_exc(sub_exc) from e + raise KeyboardInterrupt from e # Both user code and internal code raised non-cancelled exceptions. - # We "hide" the internal exception(s) in the __cause__ and surface - # the user_error. + # We set the context to be an exception group containing internal exceptions + # and, if not None, `user_error.__context__` if user_error is not None: - raise copy_exc(user_error) from e + exceptions = [subexc for subexc in e.exceptions if subexc is not user_error] + eg_substr = '' + # there's technically loss of info here, with __suppress_context__=True you + # still have original __context__ available, just not printed. But we delete + # it completely because we can't partially suppress the group + if user_error.__context__ is not None and not user_error.__suppress_context__: + exceptions.append(user_error.__context__) + eg_substr = ' and the context for the user exception' + eg_str = ( + "Both internal and user exceptions encountered. This group contains " + "the internal exception(s)" + eg_substr + "." + ) + user_error.__context__ = BaseExceptionGroup(eg_str, exceptions) + user_error.__suppress_context__ = False + _raise(user_error) raise TrioWebsocketInternalError( "The trio-websocket API is not expected to raise multiple exceptions. " From 1c5bf53130d95d9ea2b8a231ed920940374af6ce Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 23 Oct 2024 12:36:48 +0200 Subject: [PATCH 16/37] fix pylint --- tests/test_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 304390d..326a133 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -496,7 +496,7 @@ async def handler(request): assert exc_info.value is user_error e_context = exc_info.value.__context__ - assert isinstance(e_context, BaseExceptionGroup) + assert isinstance(e_context, BaseExceptionGroup) # pylint: disable=possibly-used-before-assignment assert internal_error in e_context.exceptions assert user_error_context in e_context.exceptions From 16d8d9f1dbfba049f0e00596b22c94e3388f3e09 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 25 Oct 2024 11:59:07 +0200 Subject: [PATCH 17/37] remove unused _copy_exc --- trio_websocket/_impl.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 4ae7338..fa16cfb 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -92,16 +92,6 @@ def __exit__(self, ty, value, tb): filtered_exception = _ignore_cancel(value) return filtered_exception is None -def copy_exc(e: BaseException) -> BaseException: - """Copy an exception. - - `copy.copy` fails on `trio.Cancelled`, and on exceptions with a custom `__init__` - that calls `super().__init__()`. It may be the case that this also fails on something. - """ - cls = type(e) - result = cls.__new__(cls) - result.__dict__ = copy.copy(e.__dict__) - return result @asynccontextmanager async def open_websocket( From 1b0be0583c2195f4d69f67617399b9ed125d7483 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 25 Oct 2024 12:24:11 +0200 Subject: [PATCH 18/37] small cleanups --- trio_websocket/_impl.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index fa16cfb..62c7ff7 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1,6 +1,5 @@ from __future__ import annotations -import copy import sys from collections import OrderedDict from contextlib import asynccontextmanager @@ -152,14 +151,14 @@ async def open_websocket( # yield to user code. If only one of those raise a non-cancelled exception # we will raise that non-cancelled exception. # If we get multiple cancelled, we raise the user's cancelled. - # If both raise exceptions, we raise the user code's exception with the entire - # exception group as the __cause__. + # If both raise exceptions, we raise the user code's exception with __context__ + # set to a group containing internal exception(s) + any user exception __context__ # If we somehow get multiple exceptions, but no user exception, then we raise # TrioWebsocketInternalError. # If closing the connection fails, then that will be raised as the top # exception in the last `finally`. If we encountered exceptions in user code - # or in reader task then they will be set as the `__cause__`. + # or in reader task then they will be set as the `__context__`. async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection: @@ -183,6 +182,8 @@ async def _close_connection(connection: WebSocketConnection) -> None: raise DisconnectionTimeout from None def _raise(exc: BaseException) -> NoReturn: + """This helper allows re-raising an exception without __context__ being set.""" + # cause does not need special handlng, we simply avoid using `raise .. from ..` __tracebackhide__ = True context = exc.__context__ try: @@ -199,11 +200,7 @@ def _raise(exc: BaseException) -> NoReturn: # the exception we raise also being inside the group that's set as the context. # This leads to loss of info unless properly handled. # See https://github.com/python-trio/flake8-async/issues/298 - # We therefore save the exception before raising it, and save our intended context, - # so they can be modified in the `finally`. - exc_to_raise = None - exc_context = None - # by avoiding use of `raise .. from ..` we leave the original __cause__ + # We therefore avoid having the exceptiongroup included as either cause or context try: async with trio.open_nursery() as new_nursery: @@ -243,7 +240,7 @@ def _raise(exc: BaseException) -> NoReturn: _raise(user_error) # multiple internal Cancelled is not possible afaik # but if so we just raise one of them - _raise(e.exceptions[0]) + _raise(e.exceptions[0]) # pragma: no cover # raise the non-cancelled exception _raise(exception_to_raise) From b8d1fc7fea4cadd70619cfe852711e02a3095c42 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 29 Oct 2024 12:24:52 +0100 Subject: [PATCH 19/37] make exceptions copy- and pickleable --- tests/test_connection.py | 14 ++++++++++++++ trio_websocket/_impl.py | 4 ++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 326a133..0837aa5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -31,6 +31,7 @@ ''' from __future__ import annotations +import copy from functools import partial, wraps import re import ssl @@ -1205,3 +1206,16 @@ async def server(): async with trio.open_nursery() as nursery: nursery.start_soon(server) nursery.start_soon(client) + + +def test_copy_exceptions(): + # test that exceptions are copy- and pickleable + copy.copy(HandshakeError()) + copy.copy(ConnectionTimeout()) + copy.copy(DisconnectionTimeout()) + assert copy.copy(ConnectionClosed("foo")).reason == "foo" + + rej_copy = copy.copy(ConnectionRejected(404, (("a", "b"),), b"c")) + assert rej_copy.status_code == 404 + assert rej_copy.headers == (("a", "b"),) + assert rej_copy.body == b"c" diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 62c7ff7..5f3a9d4 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -608,7 +608,7 @@ def __init__(self, reason): :param reason: :type reason: CloseReason ''' - super().__init__() + super().__init__(reason) self.reason = reason def __repr__(self): @@ -628,7 +628,7 @@ def __init__(self, status_code, headers, body): :param reason: :type reason: CloseReason ''' - super().__init__() + super().__init__(status_code, headers, body) #: a 3 digit HTTP status code self.status_code = status_code #: a tuple of 2-tuples containing header key/value pairs From 3354a33c1ae8883f1677e2450a8d0c6f4b43355b Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Sun, 12 Jan 2025 19:55:45 -0600 Subject: [PATCH 20/37] Re-run black again --- tests/test_connection.py | 41 ++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 7a0bab6..7cce20b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -83,7 +83,7 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin -WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.'))) +WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split("."))) HOST = "127.0.0.1" RESOURCE = "/resource" @@ -468,17 +468,20 @@ async def handler(request): assert header_value == b"My test header" - - @fail_after(5) async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock): """_reader_task._handle_ping_event triggers KeyboardInterrupt. user code also raises exception. Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup """ + async def ki_raising_ping_handler(*args, **kwargs) -> None: raise KeyboardInterrupt - monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler) + + monkeypatch.setattr( + WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler + ) + async def handler(request): server_ws = await request.accept() await server_ws.ping(b"a") @@ -494,6 +497,7 @@ async def handler(request): assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE) assert any(isinstance(e, trio.TooSlowError) for e in e_cause.exceptions) + @fail_after(5) async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock): """_reader_task._handle_ping_event triggers ValueError. @@ -504,10 +508,12 @@ async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock) internal_error.__context__ = TypeError() user_error = NameError() user_error_context = KeyError() + async def raising_ping_event(*args, **kwargs) -> None: raise internal_error monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event) + async def handler(request): server_ws = await request.accept() await server_ws.ping(b"a") @@ -521,26 +527,30 @@ async def handler(request): assert exc_info.value is user_error e_context = exc_info.value.__context__ - assert isinstance(e_context, BaseExceptionGroup) # pylint: disable=possibly-used-before-assignment + assert isinstance( + e_context, BaseExceptionGroup + ) # pylint: disable=possibly-used-before-assignment assert internal_error in e_context.exceptions assert user_error_context in e_context.exceptions + @fail_after(5) async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock): """Both user code and _reader_task raise Cancellation. Check that open_websocket reraises the one from user code for traceback reasons. """ - async def sleeping_ping_event(*args, **kwargs) -> None: await trio.sleep_forever() # We monkeypatch WebSocketConnection._handle_ping_event to ensure it will actually # raise Cancelled upon being cancelled. For some reason it doesn't otherwise. monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", sleeping_ping_event) + async def handler(request): server_ws = await request.accept() await server_ws.ping(b"a") + user_cancelled = None user_cancelled_cause = None user_cancelled_context = None @@ -561,10 +571,14 @@ async def handler(request): assert exc_info.value.__cause__ is user_cancelled_cause assert exc_info.value.__context__ is user_cancelled_context + def _trio_default_non_strict_exception_groups() -> bool: - assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" + assert re.match( + r"^0\.\d\d\.", trio.__version__ + ), "unexpected trio versioning scheme" return int(trio.__version__[2:4]) < 25 + @fail_after(1) async def test_handshake_exception_before_accept() -> None: """In #107, a request handler that throws an exception before finishing the @@ -575,7 +589,9 @@ async def handler(request): raise ValueError() # pylint fails to resolve that BaseExceptionGroup will always be available - with pytest.raises((BaseExceptionGroup, ValueError)) as exc: # pylint: disable=possibly-used-before-assignment + with pytest.raises( + (BaseExceptionGroup, ValueError) + ) as exc: # pylint: disable=possibly-used-before-assignment async with trio.open_nursery() as nursery: server = await nursery.start(serve_websocket, handler, HOST, 0, None) async with open_websocket( @@ -591,15 +607,15 @@ async def handler(request): # 2. WebSocketServer.run # 3. trio.serve_listeners # 4. WebSocketServer._handle_connection - assert RaisesGroup( - RaisesGroup( - RaisesGroup( - RaisesGroup(ValueError)))).matches(exc.value) + assert RaisesGroup(RaisesGroup(RaisesGroup(RaisesGroup(ValueError)))).matches( + exc.value + ) async def test_user_exception_cause(nursery) -> None: async def handler(request): await request.accept() + server = await nursery.start(serve_websocket, handler, HOST, 0, None) e_context = TypeError("foo") e_primary = ValueError("bar") @@ -615,6 +631,7 @@ async def handler(request): assert e.__cause__ is e_cause assert e.__context__ is e_context + @fail_after(1) async def test_reject_handshake(nursery): async def handler(request): From 00fc838aa276677349eaebe23a26481ceb2500f1 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Sun, 12 Jan 2025 22:24:18 -0600 Subject: [PATCH 21/37] Add/complete type annotations --- mypy.ini | 25 +++ setup.py | 1 + trio_websocket/_impl.py | 397 +++++++++++++++++++++++++--------------- trio_websocket/py.typed | 0 4 files changed, 278 insertions(+), 145 deletions(-) create mode 100644 mypy.ini create mode 100644 trio_websocket/py.typed diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..92e5727 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,25 @@ +[mypy] +files = trio_websocket +check_untyped_defs = true +disallow_any_decorated = true +disallow_any_generics = true +disallow_any_unimported = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +ignore_missing_imports = true +local_partial_types = true +no_implicit_optional = true +no_implicit_reexport = true +show_column_numbers = true +show_error_codes = true +show_traceback = true +strict = true +strict_equality = true +warn_redundant_casts = true +warn_return_any = true +warn_unreachable = true +warn_unused_configs = true +warn_unused_ignores = true diff --git a/setup.py b/setup.py index 17a21f9..b38bb70 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ python_requires=">=3.8", keywords='websocket client server trio', packages=find_packages(exclude=['docs', 'examples', 'tests']), + package_data={"trio-websocket": ["py.typed"]}, install_requires=[ 'exceptiongroup; python_version<"3.11"', 'trio>=0.11', diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 5f3a9d4..29ac08c 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -2,7 +2,7 @@ import sys from collections import OrderedDict -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, AbstractAsyncContextManager from functools import partial from ipaddress import ip_address import itertools @@ -11,7 +11,7 @@ import ssl import struct import urllib.parse -from typing import Iterable, List, NoReturn, Optional, Union +from typing import Any, List, NoReturn, Optional, Union, TypeVar, TYPE_CHECKING, Generic, cast import outcome import trio @@ -36,7 +36,11 @@ # pylint doesn't care about the version_info check, so need to ignore the warning from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin -_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22) +if TYPE_CHECKING: + from types import TracebackType + from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Coroutine, Sequence + +_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22) # type: ignore[attr-defined] if _IS_TRIO_MULTI_ERROR: _TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member @@ -49,6 +53,9 @@ RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB logger = logging.getLogger('trio-websocket') +T = TypeVar("T") +E = TypeVar("E", bound=BaseException) + class TrioWebsocketInternalError(Exception): """Raised as a fallback when open_websocket is unable to unwind an exceptiongroup @@ -57,7 +64,7 @@ class TrioWebsocketInternalError(Exception): """ -def _ignore_cancel(exc): +def _ignore_cancel(exc: E) -> E | None: return None if isinstance(exc, trio.Cancelled) else exc @@ -73,18 +80,23 @@ class _preserve_current_exception: """ __slots__ = ("_armed",) - def __init__(self): + def __init__(self) -> None: self._armed = False - def __enter__(self): + def __enter__(self) -> None: self._armed = sys.exc_info()[1] is not None - def __exit__(self, ty, value, tb): + def __exit__( + self, + ty: type[BaseException] | None, + value: BaseException | None, + tb: TracebackType | None, + ) -> bool: if value is None or not self._armed: return False if _IS_TRIO_MULTI_ERROR: # pragma: no cover - filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member + filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # type: ignore[attr-defined] # pylint: disable=no-member elif isinstance(value, BaseExceptionGroup): # pylint: disable=possibly-used-before-assignment filtered_exception = value.subgroup(lambda exc: not isinstance(exc, trio.Cancelled)) else: @@ -94,18 +106,18 @@ def __exit__(self, ty, value, tb): @asynccontextmanager async def open_websocket( - host: str, - port: int, - resource: str, - *, - use_ssl: Union[bool, ssl.SSLContext], - subprotocols: Optional[Iterable[str]] = None, - extra_headers: Optional[list[tuple[bytes,bytes]]] = None, - message_queue_size: int = MESSAGE_QUEUE_SIZE, - max_message_size: int = MAX_MESSAGE_SIZE, - connect_timeout: float = CONN_TIMEOUT, - disconnect_timeout: float = CONN_TIMEOUT - ): + host: str, + port: int, + resource: str, + *, + use_ssl: Union[bool, ssl.SSLContext], + subprotocols: Optional[Iterable[str]] = None, + extra_headers: Optional[list[tuple[bytes,bytes]]] = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT +) -> AsyncGenerator[WebSocketConnection, None]: ''' Open a WebSocket client connection to a host. @@ -286,10 +298,18 @@ def _raise(exc: BaseException) -> NoReturn: result.unwrap() -async def connect_websocket(nursery, host, port, resource, *, use_ssl, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE - ) -> WebSocketConnection: +async def connect_websocket( + nursery: trio.Nursery, + host: str, + port: int, + resource: str, + *, + use_ssl: bool | ssl.SSLContext, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketConnection: ''' Return an open WebSocket client connection to a host. @@ -352,10 +372,17 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, return connection -def open_websocket_url(url, ssl_context=None, *, subprotocols=None, - extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, - connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): +def open_websocket_url( + url: str, + ssl_context: ssl.SSLContext | None = None, + *, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT, +) -> AbstractAsyncContextManager[WebSocketConnection]: ''' Open a WebSocket client connection to a URL. @@ -386,17 +413,24 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. ''' - host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return open_websocket(host, port, resource, use_ssl=ssl_context, + host, port, resource, return_ssl_context = _url_to_host(url, ssl_context) + return open_websocket(host, port, resource, use_ssl=return_ssl_context, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size, connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout) -async def connect_websocket_url(nursery, url, ssl_context=None, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): +async def connect_websocket_url( + nursery: trio.Nursery, + url: str, + ssl_context: ssl.SSLContext | None = None, + *, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketConnection: ''' Return an open WebSocket client connection to a URL. @@ -423,14 +457,14 @@ async def connect_websocket_url(nursery, url, ssl_context=None, *, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection ''' - host, port, resource, ssl_context = _url_to_host(url, ssl_context) + host, port, resource, return_ssl_context = _url_to_host(url, ssl_context) return await connect_websocket(nursery, host, port, resource, - use_ssl=ssl_context, subprotocols=subprotocols, + use_ssl=return_ssl_context, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size) -def _url_to_host(url, ssl_context): +def _url_to_host(url: str, ssl_context: ssl.SSLContext | None) -> tuple[str, int, str, ssl.SSLContext | bool]: ''' Convert a WebSocket URL to a (host,port,resource) tuple. @@ -446,11 +480,16 @@ def _url_to_host(url, ssl_context): parts = urllib.parse.urlsplit(url) if parts.scheme not in ('ws', 'wss'): raise ValueError('WebSocket URL scheme must be "ws:" or "wss:"') + return_ssl_context: ssl.SSLContext | bool if ssl_context is None: - ssl_context = parts.scheme == 'wss' + return_ssl_context = parts.scheme == 'wss' elif parts.scheme == 'ws': raise ValueError('SSL context must be None for ws: URL scheme') + else: + return_ssl_context = ssl_context host = parts.hostname + if host is None: + raise ValueError('URL host must not be None') if parts.port is not None: port = parts.port else: @@ -463,12 +502,20 @@ def _url_to_host(url, ssl_context): path_qs = '/' if '?' in url: path_qs += '?' + parts.query - return host, port, path_qs, ssl_context - - -async def wrap_client_stream(nursery, stream, host, resource, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): + return host, port, path_qs, return_ssl_context + + +async def wrap_client_stream( + nursery: trio.Nursery, + stream: trio.SocketStream | trio.SSLStream[trio.SocketStream], + host: str, + resource: str, + *, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketConnection: ''' Wrap an arbitrary stream in a WebSocket connection. @@ -505,8 +552,12 @@ async def wrap_client_stream(nursery, stream, host, resource, *, return connection -async def wrap_server_stream(nursery, stream, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): +async def wrap_server_stream( + nursery: trio.Nursery, + stream: trio.abc.Stream, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketRequest: ''' Wrap an arbitrary stream in a server-side WebSocket. @@ -523,7 +574,8 @@ async def wrap_server_stream(nursery, stream, :type stream: trio.abc.Stream :rtype: WebSocketRequest ''' - connection = WebSocketConnection(stream, + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.SERVER), message_queue_size=message_queue_size, max_message_size=max_message_size) @@ -532,10 +584,19 @@ async def wrap_server_stream(nursery, stream, return request -async def serve_websocket(handler, host, port, ssl_context, *, - handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED): +async def serve_websocket( + handler: Callable[[WebSocketRequest], Awaitable[None]], + host: str | bytes | None, + port: int, + ssl_context: ssl.SSLContext | None, + *, + handler_nursery: trio.Nursery | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT, + task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, +) -> NoReturn: ''' Serve a WebSocket over TCP. @@ -571,12 +632,13 @@ async def serve_websocket(handler, host, port, ssl_context, *, :param task_status: Part of Trio nursery start protocol. :returns: This function runs until cancelled. ''' + open_tcp_listeners: partial[Coroutine[Any, Any, list[trio.SocketListener]]] | partial[Coroutine[Any, Any, list[trio.SSLListener[trio.SocketStream]]]] if ssl_context is None: open_tcp_listeners = partial(trio.open_tcp_listeners, port, host=host) else: open_tcp_listeners = partial(trio.open_ssl_over_tcp_listeners, port, ssl_context, host=host, https_compatible=True) - listeners = await open_tcp_listeners() + listeners: list[trio.SSLListener[trio.SocketStream]] | list[trio.SocketListener] = await open_tcp_listeners() server = WebSocketServer(handler, listeners, handler_nursery=handler_nursery, message_queue_size=message_queue_size, max_message_size=max_message_size, connect_timeout=connect_timeout, @@ -601,7 +663,7 @@ class ConnectionClosed(Exception): A WebSocket operation cannot be completed because the connection is closed or in the process of closing. ''' - def __init__(self, reason): + def __init__(self, reason: CloseReason | None) -> None: ''' Constructor. @@ -611,7 +673,7 @@ def __init__(self, reason): super().__init__(reason) self.reason = reason - def __repr__(self): + def __repr__(self) -> str: ''' Return representation. ''' return f'{self.__class__.__name__}<{self.reason}>' @@ -621,7 +683,7 @@ class ConnectionRejected(HandshakeError): A WebSocket connection could not be established because the server rejected the connection attempt. ''' - def __init__(self, status_code, headers, body): + def __init__(self, status_code: int, headers: tuple[tuple[bytes, bytes], ...], body: bytes | None): ''' Constructor. @@ -636,14 +698,14 @@ def __init__(self, status_code, headers, body): #: an optional ``bytes`` response body self.body = body - def __repr__(self): + def __repr__(self) -> str: ''' Return representation. ''' return f'{self.__class__.__name__}' class CloseReason: ''' Contains information about why a WebSocket was closed. ''' - def __init__(self, code, reason): + def __init__(self, code: int, reason: str | None) -> None: ''' Constructor. @@ -665,34 +727,34 @@ def __init__(self, code, reason): self._reason = reason @property - def code(self): + def code(self) -> int: ''' (Read-only) The numeric close code. ''' return self._code @property - def name(self): + def name(self) -> str: ''' (Read-only) The human-readable close code. ''' return self._name @property - def reason(self): + def reason(self) -> str | None: ''' (Read-only) An arbitrary reason string. ''' return self._reason - def __repr__(self): + def __repr__(self) -> str: ''' Show close code, name, and reason. ''' return f'{self.__class__.__name__}' \ f'' -class Future: +class Future(Generic[T]): ''' Represents a value that will be available in the future. ''' - def __init__(self): + def __init__(self) -> None: ''' Constructor. ''' - self._value = None + self._value: T | None = None self._value_event = trio.Event() - def set_value(self, value): + def set_value(self, value: T) -> None: ''' Set a value, which will notify any waiters. @@ -701,14 +763,14 @@ def set_value(self, value): self._value = value self._value_event.set() - async def wait_value(self): + async def wait_value(self) -> T: ''' Wait for this future to have a value, then return it. :returns: The value set by ``set_value()``. ''' await self._value_event.wait() - return self._value + return cast(T, self._value) class WebSocketRequest: @@ -718,7 +780,7 @@ class WebSocketRequest: The server may modify the handshake or leave it as is. The server should call ``accept()`` to finish the handshake and obtain a connection object. ''' - def __init__(self, connection, event): + def __init__(self, connection: WebSocketConnection, event: wsproto.events.Request) -> None: ''' Constructor. @@ -729,7 +791,7 @@ def __init__(self, connection, event): self._event = event @property - def headers(self): + def headers(self) -> list[tuple[bytes, bytes]]: ''' HTTP headers represented as a list of (name, value) pairs. @@ -738,7 +800,7 @@ def headers(self): return self._event.extra_headers @property - def path(self): + def path(self) -> str: ''' The requested URL path. @@ -747,7 +809,7 @@ def path(self): return self._event.target @property - def proposed_subprotocols(self): + def proposed_subprotocols(self) -> tuple[str, ...]: ''' A tuple of protocols proposed by the client. @@ -756,7 +818,7 @@ def proposed_subprotocols(self): return tuple(self._event.subprotocols) @property - def local(self): + def local(self) -> Endpoint | str: ''' The connection's local endpoint. @@ -765,7 +827,7 @@ def local(self): return self._connection.local @property - def remote(self): + def remote(self) -> Endpoint | str: ''' The connection's remote endpoint. @@ -773,7 +835,12 @@ def remote(self): ''' return self._connection.remote - async def accept(self, *, subprotocol=None, extra_headers=None): + async def accept( + self, + *, + subprotocol: str | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + ) -> WebSocketConnection: ''' Accept the request and return a connection object. @@ -789,7 +856,13 @@ async def accept(self, *, subprotocol=None, extra_headers=None): await self._connection._accept(self._event, subprotocol, extra_headers) return self._connection - async def reject(self, status_code, *, extra_headers=None, body=None): + async def reject( + self, + status_code: int, + *, + extra_headers: list[tuple[bytes, bytes]] | None = None, + body: bytes | None = None, + ) -> None: ''' Reject the handshake. @@ -807,7 +880,11 @@ async def reject(self, status_code, *, extra_headers=None, body=None): await self._connection._reject(status_code, extra_headers, body) -def _get_stream_endpoint(stream, *, local): +def _get_stream_endpoint( + stream: trio.abc.Stream, + *, + local: bool, +) -> Endpoint | str: ''' Construct an endpoint from a stream. @@ -823,6 +900,7 @@ def _get_stream_endpoint(stream, *, local): elif isinstance(stream, trio.SSLStream): socket = stream.transport_stream.socket is_ssl = True + endpoint: Endpoint | str if socket: addr, port, *_ = socket.getsockname() if local else socket.getpeername() endpoint = Endpoint(addr, port, is_ssl) @@ -837,16 +915,17 @@ class WebSocketConnection(trio.abc.AsyncResource): CONNECTION_ID = itertools.count() def __init__( - self, - stream: trio.SocketStream | trio.SSLStream[trio.SocketStream], - ws_connection: wsproto.WSConnection, - *, - host=None, - path=None, - client_subprotocols=None, client_extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE - ): + self, + stream: trio.abc.Stream, + ws_connection: wsproto.WSConnection, + *, + host: str | None = None, + path: str | None = None, + client_subprotocols: Iterable[str] | None = None, + client_extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE + ) -> None: ''' Constructor. @@ -886,16 +965,18 @@ def __init__( self._max_message_size = max_message_size self._reader_running = True if ws_connection.client: + assert host is not None + assert path is not None self._initial_request: Optional[Request] = Request(host=host, target=path, - subprotocols=client_subprotocols, + subprotocols=list(client_subprotocols or ()), extra_headers=client_extra_headers or []) else: self._initial_request = None self._path = path self._subprotocol: Optional[str] = None - self._handshake_headers: tuple[tuple[str,str], ...] = tuple() + self._handshake_headers: tuple[tuple[bytes, bytes], ...] = () self._reject_status = 0 - self._reject_headers: tuple[tuple[str,str], ...] = tuple() + self._reject_headers: tuple[tuple[bytes, bytes], ...] = () self._reject_body = b'' self._send_channel, self._recv_channel = trio.open_memory_channel[ Union[bytes, str] @@ -903,7 +984,7 @@ def __init__( self._pings: OrderedDict[bytes, trio.Event] = OrderedDict() # Set when the server has received a connection request event. This # future is never set on client connections. - self._connection_proposal = Future() + self._connection_proposal: Future[WebSocketRequest] | None = Future[WebSocketRequest]() # Set once the WebSocket open handshake takes place, i.e. # ConnectionRequested for server or ConnectedEstablished for client. self._open_handshake = trio.Event() @@ -915,7 +996,7 @@ def __init__( self._for_testing_peer_closed_connection = trio.Event() @property - def closed(self): + def closed(self) -> CloseReason | None: ''' (Read-only) The reason why the connection was or is being closed, else ``None``. @@ -925,17 +1006,17 @@ def closed(self): return self._close_reason @property - def is_client(self): + def is_client(self) -> bool: ''' (Read-only) Is this a client instance? ''' return self._wsproto.client @property - def is_server(self): + def is_server(self) -> bool: ''' (Read-only) Is this a server instance? ''' return not self._wsproto.client @property - def local(self): + def local(self) -> Endpoint | str: ''' The local endpoint of the connection. @@ -944,7 +1025,7 @@ def local(self): return _get_stream_endpoint(self._stream, local=True) @property - def remote(self): + def remote(self) -> Endpoint | str: ''' The remote endpoint of the connection. @@ -953,17 +1034,17 @@ def remote(self): return _get_stream_endpoint(self._stream, local=False) @property - def path(self): + def path(self) -> str | None: ''' The requested URL path. For clients, this is set when the connection is instantiated. For servers, it is set after the handshake completes. - :rtype: str + :rtype: str or None ''' return self._path @property - def subprotocol(self): + def subprotocol(self) -> str | None: ''' (Read-only) The negotiated subprotocol, or ``None`` if there is no subprotocol. @@ -975,7 +1056,7 @@ def subprotocol(self): return self._subprotocol @property - def handshake_headers(self): + def handshake_headers(self) -> tuple[tuple[bytes, bytes], ...]: ''' The HTTP headers that were sent by the remote during the handshake, stored as 2-tuples containing key/value pairs. Header keys are always @@ -985,7 +1066,7 @@ def handshake_headers(self): ''' return self._handshake_headers - async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-differ + async def aclose(self, code: int = 1000, reason: str | None = None) -> None: # pylint: disable=arguments-differ ''' Close the WebSocket connection. @@ -1003,7 +1084,7 @@ async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-dif with _preserve_current_exception(): await self._aclose(code, reason) - async def _aclose(self, code, reason): + async def _aclose(self, code: int, reason: str | None) -> None: if self._close_reason: # Per AsyncResource interface, calling aclose() on a closed resource # should succeed. @@ -1029,7 +1110,7 @@ async def _aclose(self, code, reason): # stream is closed. await self._close_stream() - async def get_message(self): + async def get_message(self) -> str | bytes: ''' Receive the next WebSocket message. @@ -1052,7 +1133,7 @@ async def get_message(self): raise ConnectionClosed(self._close_reason) from None return message - async def ping(self, payload: bytes|None=None): + async def ping(self, payload: bytes | None = None) -> None: ''' Send WebSocket ping to remote endpoint and wait for a correspoding pong. @@ -1083,7 +1164,7 @@ async def ping(self, payload: bytes|None=None): await self._send(Ping(payload=payload)) await event.wait() - async def pong(self, payload=None): + async def pong(self, payload: bytes | None = None) -> None: ''' Send an unsolicted pong. @@ -1094,9 +1175,9 @@ async def pong(self, payload=None): ''' if self._close_reason: raise ConnectionClosed(self._close_reason) - await self._send(Pong(payload=payload)) + await self._send(Pong(payload=payload or b'')) - async def send_message(self, message): + async def send_message(self, message: str | bytes) -> None: ''' Send a WebSocket message. @@ -1106,6 +1187,7 @@ async def send_message(self, message): ''' if self._close_reason: raise ConnectionClosed(self._close_reason) + event: TextMessage | BytesMessage if isinstance(message, str): event = TextMessage(data=message) elif isinstance(message, bytes): @@ -1114,12 +1196,17 @@ async def send_message(self, message): raise ValueError('message must be str or bytes') await self._send(event) - def __str__(self): + def __str__(self) -> str: ''' Connection ID and type. ''' type_ = 'client' if self.is_client else 'server' return f'{type_}-{self._id}' - async def _accept(self, request, subprotocol, extra_headers): + async def _accept( + self, + request: Request, + subprotocol: str | None, + extra_headers: list[tuple[bytes, bytes]], + ) -> None: ''' Accept the handshake. @@ -1137,7 +1224,12 @@ async def _accept(self, request, subprotocol, extra_headers): extra_headers=extra_headers)) self._open_handshake.set() - async def _reject(self, status_code, headers, body): + async def _reject( + self, + status_code: int, + headers: list[tuple[bytes, bytes]], + body: bytes, + ) -> None: ''' Reject the handshake. @@ -1149,7 +1241,7 @@ async def _reject(self, status_code, headers, body): :param bytes body: An optional response body. ''' if body: - headers.append(('Content-length', str(len(body)).encode('ascii'))) + headers.append((b'Content-length', str(len(body)).encode('ascii'))) reject_conn = RejectConnection(status_code=status_code, headers=headers, has_body=bool(body)) await self._send(reject_conn) @@ -1159,7 +1251,7 @@ async def _reject(self, status_code, headers, body): self._close_reason = CloseReason(1006, 'Rejected WebSocket handshake') self._close_handshake.set() - async def _abort_web_socket(self): + async def _abort_web_socket(self) -> None: ''' If a stream is closed outside of this class, e.g. due to network conditions or because some other code closed our stream object, then we @@ -1176,7 +1268,7 @@ async def _abort_web_socket(self): # (e.g. self.aclose()) to resume. self._close_handshake.set() - async def _close_stream(self): + async def _close_stream(self) -> None: ''' Close the TCP connection. ''' self._reader_running = False try: @@ -1186,7 +1278,7 @@ async def _close_stream(self): # This means the TCP connection is already dead. pass - async def _close_web_socket(self, code, reason=None): + async def _close_web_socket(self, code: int, reason: str | None = None) -> None: ''' Mark the WebSocket as closed. Close the message channel so that if any tasks are suspended in get_message(), they will wake up with a @@ -1197,7 +1289,7 @@ async def _close_web_socket(self, code, reason=None): logger.debug('%s websocket closed %r', self, exc) await self._send_channel.aclose() - async def _get_request(self): + async def _get_request(self) -> WebSocketRequest: ''' Return a proposal for a WebSocket handshake. @@ -1215,7 +1307,7 @@ async def _get_request(self): self._connection_proposal = None return proposal - async def _handle_request_event(self, event): + async def _handle_request_event(self, event: wsproto.events.Request) -> None: ''' Handle a connection request. @@ -1225,9 +1317,10 @@ async def _handle_request_event(self, event): :param event: ''' proposal = WebSocketRequest(self, event) + assert self._connection_proposal is not None self._connection_proposal.set_value(proposal) - async def _handle_accept_connection_event(self, event): + async def _handle_accept_connection_event(self, event: wsproto.events.AcceptConnection) -> None: ''' Handle an AcceptConnection event. @@ -1237,7 +1330,7 @@ async def _handle_accept_connection_event(self, event): self._handshake_headers = tuple(event.extra_headers) self._open_handshake.set() - async def _handle_reject_connection_event(self, event): + async def _handle_reject_connection_event(self, event: wsproto.events.RejectConnection) -> None: ''' Handle a RejectConnection event. @@ -1249,7 +1342,7 @@ async def _handle_reject_connection_event(self, event): raise ConnectionRejected(self._reject_status, self._reject_headers, body=None) - async def _handle_reject_data_event(self, event): + async def _handle_reject_data_event(self, event: wsproto.events.RejectData) -> None: ''' Handle a RejectData event. @@ -1260,7 +1353,7 @@ async def _handle_reject_data_event(self, event): raise ConnectionRejected(self._reject_status, self._reject_headers, body=self._reject_body) - async def _handle_close_connection_event(self, event): + async def _handle_close_connection_event(self, event: wsproto.events.CloseConnection) -> None: ''' Handle a close event. @@ -1281,7 +1374,7 @@ async def _handle_close_connection_event(self, event): if self.is_server: await self._close_stream() - async def _handle_message_event(self, event): + async def _handle_message_event(self, event: wsproto.events.BytesMessage | wsproto.events.TextMessage) -> None: ''' Handle a message event. @@ -1299,8 +1392,12 @@ async def _handle_message_event(self, event): await self._recv_channel.aclose() self._reader_running = False elif event.message_finished: - msg = (b'' if isinstance(event, BytesMessage) else '') \ - .join(self._message_parts) + msg: str | bytes + # Type checker does not understand `_message_parts` + if isinstance(event, BytesMessage): + msg = b''.join(self._message_parts) + else: + msg = ''.join(self._message_parts) self._message_size = 0 self._message_parts = [] try: @@ -1311,7 +1408,7 @@ async def _handle_message_event(self, event): # and there's no useful cleanup that we can do here. pass - async def _handle_ping_event(self, event): + async def _handle_ping_event(self, event: wsproto.events.Ping) -> None: ''' Handle a PingReceived event. @@ -1323,7 +1420,7 @@ async def _handle_ping_event(self, event): logger.debug('%s ping %r', self, event.payload) await self._send(event.response()) - async def _handle_pong_event(self, event): + async def _handle_pong_event(self, event: wsproto.events.Pong) -> None: ''' Handle a PongReceived event. @@ -1339,20 +1436,20 @@ async def _handle_pong_event(self, event): ''' payload = bytes(event.payload) try: - event = self._pings[payload] + ping_event = self._pings[payload] except KeyError: # We received a pong that doesn't match any in-flight pongs. Nothing # we can do with it, so ignore it. return while self._pings: - key, event = self._pings.popitem(0) + key, ping_event = self._pings.popitem(False) skipped = ' [skipped] ' if payload != key else ' ' logger.debug('%s pong%s%r', self, skipped, key) - event.set() + ping_event.set() if payload == key: break - async def _reader_task(self): + async def _reader_task(self) -> None: ''' A background task that reads network data and generates events. ''' handlers = { AcceptConnection: self._handle_accept_connection_event, @@ -1382,7 +1479,7 @@ async def _reader_task(self): handler = handlers[event_type] logger.debug('%s received event: %s', self, event_type) - await handler(event) + await handler(event) # type: ignore[operator] except KeyError: logger.warning('%s received unknown event type: "%s"', self, event_type) @@ -1416,7 +1513,7 @@ async def _reader_task(self): logger.debug('%s reader task finished', self) - async def _send(self, event): + async def _send(self, event: wsproto.events.Event) -> None: ''' Send an event to the remote WebSocket. @@ -1433,12 +1530,13 @@ async def _send(self, event): await self._stream.send_all(data) except (trio.BrokenResourceError, trio.ClosedResourceError): await self._abort_web_socket() + assert self._close_reason is not None raise ConnectionClosed(self._close_reason) from None class Endpoint: ''' Represents a connection endpoint. ''' - def __init__(self, address, port, is_ssl): + def __init__(self, address: str | int, port: int, is_ssl: bool) -> None: #: IP address :class:`ipaddress.ip_address` self.address = ip_address(address) #: TCP port @@ -1447,7 +1545,7 @@ def __init__(self, address, port, is_ssl): self.is_ssl = is_ssl @property - def url(self): + def url(self) -> str: ''' Return a URL representation of a TCP endpoint, e.g. ``ws://127.0.0.1:80``. ''' scheme = 'wss' if self.is_ssl else 'ws' @@ -1460,7 +1558,7 @@ def url(self): return f'{scheme}://{self.address}{port_str}' return f'{scheme}://[{self.address}]{port_str}' - def __repr__(self): + def __repr__(self) -> str: ''' Return endpoint info as string. ''' return f'Endpoint(address="{self.address}", port={self.port}, is_ssl={self.is_ssl})' @@ -1474,10 +1572,17 @@ class WebSocketServer: instance and starts some background tasks, ''' - def __init__(self, handler, listeners, *, handler_nursery=None, - message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT): + def __init__( + self, + handler: Callable[[WebSocketRequest], Awaitable[None]], + listeners: Sequence[trio.SSLListener[trio.SocketStream] | trio.SocketListener], + *, + handler_nursery: trio.Nursery | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT, + ) -> None: ''' Constructor. @@ -1509,7 +1614,7 @@ def __init__(self, handler, listeners, *, handler_nursery=None, self._disconnect_timeout = disconnect_timeout @property - def port(self): + def port(self) -> int: """Returns the requested or kernel-assigned port number. In the case of kernel-assigned port (requested with port=0 in the @@ -1522,15 +1627,15 @@ def port(self): """ if len(self._listeners) > 1: raise RuntimeError('Cannot get port because this server has' - ' more than 1 listeners.') + ' more than 1 listener.') listener = self.listeners[0] try: - return listener.port + return listener.port # type: ignore[union-attr] except AttributeError: raise RuntimeError(f'This socket does not have a port: {repr(listener)}') from None @property - def listeners(self): + def listeners(self) -> list[Endpoint | str]: ''' Return a list of listener metadata. Each TCP listener is represented as an ``Endpoint`` instance. Other listener types are represented by their @@ -1539,13 +1644,15 @@ def listeners(self): :returns: Listeners :rtype list[Endpoint or str]: ''' - listeners = [] + listeners: list[Endpoint | str] = [] for listener in self._listeners: socket, is_ssl = None, False if isinstance(listener, trio.SocketListener): socket = listener.socket elif isinstance(listener, trio.SSLListener): - socket = listener.transport_listener.socket + internal_listener = listener.transport_listener + assert isinstance(internal_listener, trio.SocketListener) + socket = internal_listener.socket is_ssl = True if socket: sockname = socket.getsockname() @@ -1554,7 +1661,7 @@ def listeners(self): listeners.append(repr(listener)) return listeners - async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): + async def run(self, *, task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED) -> NoReturn: ''' Start serving incoming connections requests. @@ -1567,7 +1674,7 @@ async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): ''' async with trio.open_nursery() as nursery: serve_listeners = partial(trio.serve_listeners, - self._handle_connection, self._listeners, + self._handle_connection, list(self._listeners), handler_nursery=self._handler_nursery) await nursery.start(serve_listeners) logger.debug('Listening on %s', @@ -1575,7 +1682,7 @@ async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): task_status.started(self) await trio.sleep_forever() - async def _handle_connection(self, stream): + async def _handle_connection(self, stream: trio.abc.Stream) -> None: ''' Handle an incoming connection by spawning a connection background task and a handler task inside a new nursery. diff --git a/trio_websocket/py.typed b/trio_websocket/py.typed new file mode 100644 index 0000000..e69de29 From c49321fe9247c8685580a5a9ac1a0b3400fd87d9 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Sun, 12 Jan 2025 22:36:31 -0600 Subject: [PATCH 22/37] Fix line-too-long errors --- trio_websocket/_impl.py | 68 +++++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 16 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 29ac08c..6cba00a 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -40,7 +40,9 @@ from types import TracebackType from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Coroutine, Sequence -_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22) # type: ignore[attr-defined] +_IS_TRIO_MULTI_ERROR = tuple( + map(int, trio.__version__.split(".")[:2]) # type: ignore[attr-defined] +) < (0, 22) if _IS_TRIO_MULTI_ERROR: _TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member @@ -464,7 +466,10 @@ async def connect_websocket_url( max_message_size=max_message_size) -def _url_to_host(url: str, ssl_context: ssl.SSLContext | None) -> tuple[str, int, str, ssl.SSLContext | bool]: +def _url_to_host( + url: str, + ssl_context: ssl.SSLContext | None, +) -> tuple[str, int, str, ssl.SSLContext | bool]: ''' Convert a WebSocket URL to a (host,port,resource) tuple. @@ -597,7 +602,7 @@ async def serve_websocket( disconnect_timeout: float = CONN_TIMEOUT, task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, ) -> NoReturn: - ''' + """ Serve a WebSocket over TCP. This function supports the Trio nursery start protocol: ``server = await @@ -631,18 +636,33 @@ async def serve_websocket( to finish the closing handshake before timing out. :param task_status: Part of Trio nursery start protocol. :returns: This function runs until cancelled. - ''' - open_tcp_listeners: partial[Coroutine[Any, Any, list[trio.SocketListener]]] | partial[Coroutine[Any, Any, list[trio.SSLListener[trio.SocketStream]]]] + """ + open_tcp_listeners: ( + partial[Coroutine[Any, Any, list[trio.SocketListener]]] + | partial[Coroutine[Any, Any, list[trio.SSLListener[trio.SocketStream]]]] + ) if ssl_context is None: open_tcp_listeners = partial(trio.open_tcp_listeners, port, host=host) else: - open_tcp_listeners = partial(trio.open_ssl_over_tcp_listeners, port, - ssl_context, host=host, https_compatible=True) - listeners: list[trio.SSLListener[trio.SocketStream]] | list[trio.SocketListener] = await open_tcp_listeners() - server = WebSocketServer(handler, listeners, - handler_nursery=handler_nursery, message_queue_size=message_queue_size, - max_message_size=max_message_size, connect_timeout=connect_timeout, - disconnect_timeout=disconnect_timeout) + open_tcp_listeners = partial( + trio.open_ssl_over_tcp_listeners, + port, + ssl_context, + host=host, + https_compatible=True, + ) + listeners: list[trio.SSLListener[trio.SocketStream]] | list[trio.SocketListener] = ( + await open_tcp_listeners() + ) + server = WebSocketServer( + handler, + listeners, + handler_nursery=handler_nursery, + message_queue_size=message_queue_size, + max_message_size=max_message_size, + connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout, + ) await server.run(task_status=task_status) @@ -683,7 +703,12 @@ class ConnectionRejected(HandshakeError): A WebSocket connection could not be established because the server rejected the connection attempt. ''' - def __init__(self, status_code: int, headers: tuple[tuple[bytes, bytes], ...], body: bytes | None): + def __init__( + self, + status_code: int, + headers: tuple[tuple[bytes, bytes], ...], + body: bytes | None, + ) -> None: ''' Constructor. @@ -780,7 +805,11 @@ class WebSocketRequest: The server may modify the handshake or leave it as is. The server should call ``accept()`` to finish the handshake and obtain a connection object. ''' - def __init__(self, connection: WebSocketConnection, event: wsproto.events.Request) -> None: + def __init__( + self, + connection: WebSocketConnection, + event: wsproto.events.Request, + ) -> None: ''' Constructor. @@ -1374,7 +1403,10 @@ async def _handle_close_connection_event(self, event: wsproto.events.CloseConnec if self.is_server: await self._close_stream() - async def _handle_message_event(self, event: wsproto.events.BytesMessage | wsproto.events.TextMessage) -> None: + async def _handle_message_event( + self, + event: wsproto.events.BytesMessage | wsproto.events.TextMessage, + ) -> None: ''' Handle a message event. @@ -1661,7 +1693,11 @@ def listeners(self) -> list[Endpoint | str]: listeners.append(repr(listener)) return listeners - async def run(self, *, task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED) -> NoReturn: + async def run( + self, + *, + task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, + ) -> NoReturn: ''' Start serving incoming connections requests. From 45c06ad57813d54f5eca60526b30afe057b4de9b Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Mon, 13 Jan 2025 01:12:02 -0600 Subject: [PATCH 23/37] Fix remaining CI issues --- autobahn/client.py | 13 +- autobahn/server.py | 6 +- examples/client.py | 30 ++- examples/generate-cert.py | 2 +- examples/server.py | 8 +- mypy.ini | 15 +- tests/test_connection.py | 428 ++++++++++++++++++++++++------------- trio_websocket/__init__.py | 37 ++-- trio_websocket/_impl.py | 4 +- 9 files changed, 345 insertions(+), 198 deletions(-) diff --git a/autobahn/client.py b/autobahn/client.py index d93be1c..1d8a082 100644 --- a/autobahn/client.py +++ b/autobahn/client.py @@ -6,6 +6,7 @@ import json import logging import sys +from typing import Any import trio from trio_websocket import open_websocket_url, ConnectionClosed @@ -17,7 +18,7 @@ logger = logging.getLogger('client') -async def get_case_count(url): +async def get_case_count(url: str) -> int: url = url + '/getCaseCount' async with open_websocket_url(url) as conn: case_count = await conn.get_message() @@ -25,13 +26,13 @@ async def get_case_count(url): return int(case_count) -async def get_case_info(url, case): +async def get_case_info(url: str, case: str) -> Any: url = f'{url}/getCaseInfo?case={case}' async with open_websocket_url(url) as conn: return json.loads(await conn.get_message()) -async def run_case(url, case): +async def run_case(url: str, case: str) -> None: url = f'{url}/runCase?case={case}&agent={AGENT}' try: async with open_websocket_url(url, max_message_size=MAX_MESSAGE_SIZE) as conn: @@ -42,7 +43,7 @@ async def run_case(url, case): pass -async def update_reports(url): +async def update_reports(url: str) -> None: url = f'{url}/updateReports?agent={AGENT}' async with open_websocket_url(url) as conn: # This command runs as soon as we connect to it, so we don't need to @@ -50,7 +51,7 @@ async def update_reports(url): pass -async def run_tests(args): +async def run_tests(args: argparse.Namespace) -> None: logger = logging.getLogger('trio-websocket') if args.debug_cases: # Don't fetch case count when debugging a subset of test cases. It adds @@ -82,7 +83,7 @@ async def run_tests(args): sys.exit(1) -def parse_args(): +def parse_args() -> argparse.Namespace: ''' Parse command line arguments. ''' parser = argparse.ArgumentParser(description='Autobahn client for' ' trio-websocket') diff --git a/autobahn/server.py b/autobahn/server.py index ff23846..6a84de4 100644 --- a/autobahn/server.py +++ b/autobahn/server.py @@ -23,14 +23,14 @@ connection_count = 0 -async def main(): +async def main() -> None: ''' Main entry point. ''' logger.info('Starting websocket server on ws://%s:%d', BIND_IP, BIND_PORT) await serve_websocket(handler, BIND_IP, BIND_PORT, ssl_context=None, max_message_size=MAX_MESSAGE_SIZE) -async def handler(request: WebSocketRequest): +async def handler(request: WebSocketRequest) -> None: ''' Reverse incoming websocket messages and send them back. ''' global connection_count # pylint: disable=global-statement connection_count += 1 @@ -46,7 +46,7 @@ async def handler(request: WebSocketRequest): logger.exception(' runtime exception handling connection #%d', connection_count) -def parse_args(): +def parse_args() -> argparse.Namespace: ''' Parse command line arguments. ''' parser = argparse.ArgumentParser(description='Autobahn server for' ' trio-websocket') diff --git a/examples/client.py b/examples/client.py index 030c12b..ba5311c 100644 --- a/examples/client.py +++ b/examples/client.py @@ -11,16 +11,23 @@ import ssl import sys import urllib.parse +from typing import NoReturn import trio -from trio_websocket import open_websocket_url, ConnectionClosed, HandshakeError +from trio_websocket import ( + open_websocket_url, + ConnectionClosed, + HandshakeError, + WebSocketConnection, + CloseReason, +) logging.basicConfig(level=logging.DEBUG) here = pathlib.Path(__file__).parent -def commands(): +def commands() -> None: ''' Print the supported commands. ''' print('Commands: ') print('send -> send message') @@ -29,7 +36,7 @@ def commands(): print() -def parse_args(): +def parse_args() -> argparse.Namespace: ''' Parse command line arguments. ''' parser = argparse.ArgumentParser(description='Example trio-websocket client') parser.add_argument('--heartbeat', action='store_true', @@ -38,7 +45,7 @@ def parse_args(): return parser.parse_args() -async def main(args): +async def main(args: argparse.Namespace) -> bool: ''' Main entry point, returning False in the case of logged error. ''' if urllib.parse.urlsplit(args.url).scheme == 'wss': # Configure SSL context to handle our self-signed certificate. Most @@ -59,9 +66,10 @@ async def main(args): except HandshakeError as e: logging.error('Connection attempt failed: %s', e) return False + return True -async def handle_connection(ws, use_heartbeat): +async def handle_connection(ws: WebSocketConnection, use_heartbeat: bool) -> None: ''' Handle the connection. ''' logging.debug('Connected!') try: @@ -71,11 +79,12 @@ async def handle_connection(ws, use_heartbeat): nursery.start_soon(get_commands, ws) nursery.start_soon(get_messages, ws) except ConnectionClosed as cc: + assert isinstance(cc.reason, CloseReason) reason = '' if cc.reason.reason is None else f'"{cc.reason.reason}"' print(f'Closed: {cc.reason.code}/{cc.reason.name} {reason}') -async def heartbeat(ws, timeout, interval): +async def heartbeat(ws: WebSocketConnection, timeout: float, interval: float) -> NoReturn: ''' Send periodic pings on WebSocket ``ws``. @@ -99,11 +108,10 @@ async def heartbeat(ws, timeout, interval): await trio.sleep(interval) -async def get_commands(ws): +async def get_commands(ws: WebSocketConnection) -> None: ''' In a loop: get a command from the user and execute it. ''' while True: - cmd = await trio.to_thread.run_sync(input, 'cmd> ', - cancellable=True) + cmd = await trio.to_thread.run_sync(input, 'cmd> ') if cmd.startswith('ping'): payload = cmd[5:].encode('utf8') or None await ws.ping(payload) @@ -123,11 +131,11 @@ async def get_commands(ws): await trio.sleep(0.25) -async def get_messages(ws): +async def get_messages(ws: WebSocketConnection) -> None: ''' In a loop: get a WebSocket message and print it out. ''' while True: message = await ws.get_message() - print(f'message: {message}') + print(f'message: {message!r}') if __name__ == '__main__': diff --git a/examples/generate-cert.py b/examples/generate-cert.py index cc21698..fcb36bd 100644 --- a/examples/generate-cert.py +++ b/examples/generate-cert.py @@ -3,7 +3,7 @@ import trustme -def main(): +def main() -> None: here = pathlib.Path(__file__).parent ca_path = here / 'fake.ca.pem' server_path = here / 'fake.server.pem' diff --git a/examples/server.py b/examples/server.py index 611d89b..5274013 100644 --- a/examples/server.py +++ b/examples/server.py @@ -14,7 +14,7 @@ import ssl import trio -from trio_websocket import serve_websocket, ConnectionClosed +from trio_websocket import serve_websocket, ConnectionClosed, WebSocketRequest logging.basicConfig(level=logging.DEBUG) @@ -22,7 +22,7 @@ here = pathlib.Path(__file__).parent -def parse_args(): +def parse_args() -> argparse.Namespace: ''' Parse command line arguments. ''' parser = argparse.ArgumentParser(description='Example trio-websocket client') parser.add_argument('--ssl', action='store_true', help='Use SSL') @@ -32,7 +32,7 @@ def parse_args(): return parser.parse_args() -async def main(args): +async def main(args: argparse.Namespace) -> None: ''' Main entry point. ''' logging.info('Starting websocket server…') if args.ssl: @@ -48,7 +48,7 @@ async def main(args): await serve_websocket(handler, host, args.port, ssl_context) -async def handler(request): +async def handler(request: WebSocketRequest) -> None: ''' Reverse incoming websocket messages and send them back. ''' logging.info('Handler starting on path "%s"', request.path) ws = await request.accept() diff --git a/mypy.ini b/mypy.ini index 92e5727..09cd26d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,13 @@ [mypy] -files = trio_websocket +explicit_package_bases = true +files = trio_websocket,tests,autobahn,examples +show_column_numbers = true +show_error_codes = true +show_traceback = true +warn_redundant_casts = true +warn_unused_configs = true + +[mypy-trio_websocket] check_untyped_defs = true disallow_any_decorated = true disallow_any_generics = true @@ -13,13 +21,8 @@ ignore_missing_imports = true local_partial_types = true no_implicit_optional = true no_implicit_reexport = true -show_column_numbers = true -show_error_codes = true -show_traceback = true strict = true strict_equality = true -warn_redundant_casts = true warn_return_any = true warn_unreachable = true -warn_unused_configs = true warn_unused_ignores = true diff --git a/tests/test_connection.py b/tests/test_connection.py index 0837aa5..db1bf69 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -32,11 +32,13 @@ from __future__ import annotations import copy -from functools import partial, wraps import re import ssl import sys -from unittest.mock import patch +from collections.abc import AsyncGenerator +from functools import partial, wraps +from typing import TYPE_CHECKING, TypeVar, cast +from unittest.mock import Mock, patch import attr import pytest @@ -58,30 +60,38 @@ except ImportError: pass + from trio_websocket import ( - connect_websocket, - connect_websocket_url, + CloseReason, ConnectionClosed, ConnectionRejected, ConnectionTimeout, DisconnectionTimeout, Endpoint, HandshakeError, + WebSocketConnection, + WebSocketRequest, + WebSocketServer, + connect_websocket, + connect_websocket_url, open_websocket, open_websocket_url, serve_websocket, - WebSocketConnection, - WebSocketServer, - WebSocketRequest, wrap_client_stream, - wrap_server_stream + wrap_server_stream, ) - from trio_websocket._impl import _TRIO_EXC_GROUP_TYPE if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + from wsproto.events import Event + + from typing_extensions import ParamSpec + PS = ParamSpec("PS") + WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.'))) HOST = '127.0.0.1' @@ -96,19 +106,21 @@ FORCE_TIMEOUT = 2 TIMEOUT_TEST_MAX_DURATION = 3 +T = TypeVar("T") + @pytest.fixture -async def echo_server(nursery): +async def echo_server(nursery: trio.Nursery) -> AsyncGenerator[WebSocketServer, None]: ''' A server that reads one message, sends back the same message, then closes the connection. ''' serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0, ssl_context=None) server = await nursery.start(serve_fn) - yield server + yield cast(WebSocketServer, server) @pytest.fixture -async def echo_conn(echo_server): +async def echo_conn(echo_server: WebSocketServer) -> AsyncGenerator[WebSocketConnection, None]: ''' Return a client connection instance that is connected to an echo server. ''' async with open_websocket(HOST, echo_server.port, RESOURCE, @@ -116,7 +128,7 @@ async def echo_conn(echo_server): yield conn -async def echo_request_handler(request): +async def echo_request_handler(request: WebSocketRequest) -> None: ''' Accept incoming request and then pass off to echo connection handler. ''' @@ -131,35 +143,56 @@ async def echo_request_handler(request): class fail_after: ''' This decorator fails if the runtime of the decorated function (as measured by the Trio clock) exceeds the specified value. ''' - def __init__(self, seconds): + def __init__(self, seconds: int) -> None: self._seconds = seconds - def __call__(self, fn): + def __call__(self, fn: Callable[PS, Awaitable[T]]) -> Callable[PS, Awaitable[T | None]]: @wraps(fn) - async def wrapper(*args, **kwargs): + async def wrapper(*args: PS.args, **kwargs: PS.kwargs) -> T | None: + result: T | None = None with trio.move_on_after(self._seconds) as cancel_scope: - await fn(*args, **kwargs) + result = await fn(*args, **kwargs) if cancel_scope.cancelled_caught: pytest.fail(f'Test runtime exceeded the maximum {self._seconds} seconds') + return result return wrapper @attr.s(hash=False, eq=False) -class MemoryListener(trio.abc.Listener): - closed = attr.ib(default=False) +class MemoryListener( + trio.abc.Listener[ + trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] + ] +): + closed: bool = attr.ib(default=False) accepted_streams: list[ - tuple[trio.abc.SendChannel[str], trio.abc.ReceiveChannel[str]] + trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] ] = attr.ib(factory=list) - queued_streams = attr.ib(factory=lambda: trio.open_memory_channel[str](1)) - accept_hook = attr.ib(default=None) - - async def connect(self): + queued_streams: tuple[ + trio.MemorySendChannel[ + trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] + ], + trio.MemoryReceiveChannel[ + trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] + ], + ] = attr.ib(factory=lambda: trio.open_memory_channel[ + trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] + ](1)) + accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None) + + async def connect(self) -> trio.StapledStream[ + trio.testing.MemorySendStream, + trio.testing.MemoryReceiveStream, + ]: assert not self.closed client, server = memory_stream_pair() await self.queued_streams[0].send(server) return client - async def accept(self): + async def accept(self) -> trio.StapledStream[ + trio.testing.MemorySendStream, + trio.testing.MemoryReceiveStream, + ]: await trio.sleep(0) assert not self.closed if self.accept_hook is not None: @@ -168,12 +201,12 @@ async def accept(self): self.accepted_streams.append(stream) return stream - async def aclose(self): + async def aclose(self) -> None: self.closed = True await trio.sleep(0) -async def test_endpoint_ipv4(): +async def test_endpoint_ipv4() -> None: e1 = Endpoint('10.105.0.2', 80, False) assert e1.url == 'ws://10.105.0.2' assert str(e1) == 'Endpoint(address="10.105.0.2", port=80, is_ssl=False)' @@ -185,7 +218,7 @@ async def test_endpoint_ipv4(): assert str(e3) == 'Endpoint(address="0.0.0.0", port=443, is_ssl=True)' -async def test_listen_port_ipv6(): +async def test_listen_port_ipv6() -> None: e1 = Endpoint('2599:8807:6201:b7:16cf:bb9c:a6d3:51ab', 80, False) assert e1.url == 'ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]' assert str(e1) == 'Endpoint(address="2599:8807:6201:b7:16cf:bb9c:a6d3' \ @@ -198,17 +231,19 @@ async def test_listen_port_ipv6(): assert str(e3) == 'Endpoint(address="::", port=443, is_ssl=True)' -async def test_server_has_listeners(nursery): +async def test_server_has_listeners(nursery: trio.Nursery) -> None: server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) assert len(server.listeners) > 0 assert isinstance(server.listeners[0], Endpoint) -async def test_serve(nursery): +async def test_serve(nursery: trio.Nursery) -> None: task = current_task() server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) port = server.port assert server.port != 0 # The server nursery begins with one task (server.listen). @@ -221,7 +256,7 @@ async def test_serve(nursery): assert len(task.child_nurseries) == no_clients_nursery_count + 1 -async def test_serve_ssl(nursery): +async def test_serve_ssl(nursery: trio.Nursery) -> None: server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) client_context = ssl.create_default_context() ca = trustme.CA() @@ -231,19 +266,23 @@ async def test_serve_ssl(nursery): server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, server_context) + assert isinstance(server, WebSocketServer) port = server.port async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context ) as conn: assert not conn.closed + assert isinstance(conn.local, Endpoint) assert conn.local.is_ssl + assert isinstance(conn.remote, Endpoint) assert conn.remote.is_ssl -async def test_serve_handler_nursery(nursery): +async def test_serve_handler_nursery(nursery: trio.Nursery) -> None: async with trio.open_nursery() as handler_nursery: serve_with_nursery = partial(serve_websocket, echo_request_handler, HOST, 0, None, handler_nursery=handler_nursery) server = await nursery.start(serve_with_nursery) + assert isinstance(server, WebSocketServer) port = server.port # The server nursery begins with one task (server.listen). assert len(nursery.child_tasks) == 1 @@ -253,25 +292,35 @@ async def test_serve_handler_nursery(nursery): assert len(handler_nursery.child_tasks) == 1 -async def test_serve_with_zero_listeners(): +async def test_serve_with_zero_listeners() -> None: with pytest.raises(ValueError): WebSocketServer(echo_request_handler, []) -async def test_serve_non_tcp_listener(nursery): +async def test_serve_non_tcp_listener(nursery: trio.Nursery) -> None: listeners = [MemoryListener()] - server = WebSocketServer(echo_request_handler, listeners) + server = WebSocketServer( + echo_request_handler, + listeners, # type: ignore[arg-type] + ) await nursery.start(server.run) assert len(server.listeners) == 1 with pytest.raises(RuntimeError): server.port # pylint: disable=pointless-statement - assert server.listeners[0].startswith('MemoryListener(') + listener = server.listeners[0] + assert isinstance(listener, str) + assert listener.startswith('MemoryListener(') -async def test_serve_multiple_listeners(nursery): +async def test_serve_multiple_listeners(nursery: trio.Nursery) -> None: listener1 = (await trio.open_tcp_listeners(0, host=HOST))[0] listener2 = MemoryListener() - server = WebSocketServer(echo_request_handler, [listener1, listener2]) + server = WebSocketServer( + echo_request_handler, [ + listener1, + listener2, # type: ignore[list-item] + ] + ) await nursery.start(server.run) assert len(server.listeners) == 2 with pytest.raises(RuntimeError): @@ -279,13 +328,17 @@ async def test_serve_multiple_listeners(nursery): # usable if you have exactly one listener. server.port # pylint: disable=pointless-statement # The first listener metadata is a ListenPort instance. - assert server.listeners[0].port != 0 + listener_zero = server.listeners[0] + assert isinstance(listener_zero, Endpoint) + assert listener_zero.port != 0 # The second listener metadata is a string containing the repr() of a # MemoryListener object. - assert server.listeners[1].startswith('MemoryListener(') + listener_one = server.listeners[1] + assert isinstance(listener_one, str) + assert listener_one.startswith('MemoryListener(') -async def test_client_open(echo_server): +async def test_client_open(echo_server: WebSocketServer) -> None: async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) \ as conn: assert not conn.closed @@ -299,27 +352,33 @@ async def test_client_open(echo_server): (RESOURCE + '/path', RESOURCE + '/path'), (RESOURCE + '?foo=bar', RESOURCE + '?foo=bar') ]) -async def test_client_open_url(path, expected_path, echo_server): +async def test_client_open_url(path: str, expected_path: str, echo_server: WebSocketServer) -> None: url = f'ws://{HOST}:{echo_server.port}{path}' async with open_websocket_url(url) as conn: assert conn.path == expected_path -async def test_client_open_invalid_url(echo_server): +async def test_client_open_invalid_url(echo_server: WebSocketServer) -> None: with pytest.raises(ValueError): async with open_websocket_url('http://foo.com/bar'): pass -async def test_client_open_invalid_ssl(echo_server, nursery): +async def test_client_open_invalid_ssl( + echo_server: WebSocketServer, + nursery: trio.Nursery, +) -> None: with pytest.raises(TypeError, match='`use_ssl` argument must be bool or ssl.SSLContext'): - await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, use_ssl=1) + await connect_websocket( + nursery, HOST, echo_server.port, RESOURCE, + use_ssl=1, # type: ignore[arg-type] + ) url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' with pytest.raises(ValueError, match='^SSL context must be None for ws: URL scheme$' ): await connect_websocket_url(nursery, url, ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) -async def test_ascii_encoded_path_is_ok(echo_server): +async def test_ascii_encoded_path_is_ok(echo_server: WebSocketServer) -> None: path = '%D7%90%D7%91%D7%90?%D7%90%D7%9E%D7%90' url = f'ws://{HOST}:{echo_server.port}{RESOURCE}/{path}' async with open_websocket_url(url) as conn: @@ -327,7 +386,7 @@ async def test_ascii_encoded_path_is_ok(echo_server): @patch('trio_websocket._impl.open_websocket') -def test_client_open_url_options(open_websocket_mock): +def test_client_open_url_options(open_websocket_mock: Mock) -> None: """open_websocket_url() must pass its options on to open_websocket()""" port = 1234 url = f'ws://{HOST}:{port}{RESOURCE}' @@ -339,7 +398,7 @@ def test_client_open_url_options(open_websocket_mock): 'connect_timeout': 36, 'disconnect_timeout': 37, } - open_websocket_url(url, **options) + open_websocket_url(url, **options) # type: ignore[arg-type] _, call_args, call_kwargs = open_websocket_mock.mock_calls[0] assert call_args == (HOST, port, RESOURCE) assert not call_kwargs.pop('use_ssl') @@ -350,19 +409,19 @@ def test_client_open_url_options(open_websocket_mock): assert call_kwargs['use_ssl'] -async def test_client_connect(echo_server, nursery): +async def test_client_connect(echo_server: WebSocketServer, nursery: trio.Nursery) -> None: conn = await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, use_ssl=False) assert not conn.closed -async def test_client_connect_url(echo_server, nursery): +async def test_client_connect_url(echo_server: WebSocketServer, nursery: trio.Nursery) -> None: url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' conn = await connect_websocket_url(nursery, url) assert not conn.closed -async def test_connection_has_endpoints(echo_conn): +async def test_connection_has_endpoints(echo_conn: WebSocketConnection) -> None: async with echo_conn: assert isinstance(echo_conn.local, Endpoint) assert str(echo_conn.local.address) == HOST @@ -376,47 +435,53 @@ async def test_connection_has_endpoints(echo_conn): @fail_after(1) -async def test_handshake_has_endpoints(nursery): - async def handler(request): +async def test_handshake_has_endpoints(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: + assert isinstance(server, WebSocketServer) + assert isinstance(request.local, Endpoint) assert str(request.local.address) == HOST assert request.local.port == server.port assert not request.local.is_ssl + assert isinstance(request.remote, Endpoint) assert str(request.remote.address) == HOST assert not request.remote.is_ssl await request.accept() server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): pass -async def test_handshake_subprotocol(nursery): - async def handler(request): +async def test_handshake_subprotocol(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: assert request.proposed_subprotocols == ('chat', 'file') server_ws = await request.accept(subprotocol='chat') assert server_ws.subprotocol == 'chat' server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, subprotocols=('chat', 'file')) as client_ws: assert client_ws.subprotocol == 'chat' -async def test_handshake_path(nursery): - async def handler(request): +async def test_handshake_path(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: assert request.path == RESOURCE server_ws = await request.accept() assert server_ws.path == RESOURCE server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, ) as client_ws: assert client_ws.path == RESOURCE @fail_after(1) -async def test_handshake_client_headers(nursery): - async def handler(request): +async def test_handshake_client_headers(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: headers = dict(request.headers) assert b'x-test-header' in headers assert headers[b'x-test-header'] == b'My test header' @@ -424,6 +489,7 @@ async def handler(request): await server_ws.send_message('test') server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) headers = [(b'X-Test-Header', b'My test header')] async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, extra_headers=headers) as client_ws: @@ -431,12 +497,13 @@ async def handler(request): @fail_after(1) -async def test_handshake_server_headers(nursery): - async def handler(request): - headers = [('X-Test-Header', 'My test header')] +async def test_handshake_server_headers(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: + headers = [(b'X-Test-Header', b'My test header')] await request.accept(extra_headers=headers) server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False ) as client_ws: header_key, header_value = client_ws.handshake_headers[0] @@ -444,22 +511,25 @@ async def handler(request): assert header_value == b'My test header' - - @fail_after(5) -async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock): +async def test_open_websocket_internal_ki( + nursery: trio.Nursery, + monkeypatch: pytest.MonkeyPatch, + autojump_clock: trio.testing.MockClock, +) -> None: """_reader_task._handle_ping_event triggers KeyboardInterrupt. user code also raises exception. Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup """ - async def ki_raising_ping_handler(*args, **kwargs) -> None: + async def ki_raising_ping_handler(*args: object, **kwargs: object) -> None: raise KeyboardInterrupt monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler) - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await server_ws.ping(b"a") server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) with pytest.raises(KeyboardInterrupt) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): with trio.fail_after(1) as cs: @@ -471,7 +541,11 @@ async def handler(request): assert any(isinstance(e, trio.TooSlowError) for e in e_cause.exceptions) @fail_after(5) -async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock): +async def test_open_websocket_internal_exc( + nursery: trio.Nursery, + monkeypatch: pytest.MonkeyPatch, + autojump_clock: trio.testing.MockClock, +) -> None: """_reader_task._handle_ping_event triggers ValueError. user code also raises exception. internal exception is in __context__ exceptiongroup and user exc is delivered @@ -480,15 +554,16 @@ async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock) internal_error.__context__ = TypeError() user_error = NameError() user_error_context = KeyError() - async def raising_ping_event(*args, **kwargs) -> None: + async def raising_ping_event(*args: object, **kwargs: object) -> None: raise internal_error monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event) - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await server_ws.ping(b"a") server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) with pytest.raises(type(user_error)) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): await trio.lowlevel.checkpoint() @@ -502,19 +577,23 @@ async def handler(request): assert user_error_context in e_context.exceptions @fail_after(5) -async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock): +async def test_open_websocket_cancellations( + nursery: trio.Nursery, + monkeypatch: pytest.MonkeyPatch, + autojump_clock: trio.testing.MockClock, +) -> None: """Both user code and _reader_task raise Cancellation. Check that open_websocket reraises the one from user code for traceback reasons. """ - async def sleeping_ping_event(*args, **kwargs) -> None: + async def sleeping_ping_event(*args: object, **kwargs: object) -> None: await trio.sleep_forever() # We monkeypatch WebSocketConnection._handle_ping_event to ensure it will actually # raise Cancelled upon being cancelled. For some reason it doesn't otherwise. monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", sleeping_ping_event) - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await server_ws.ping(b"a") user_cancelled = None @@ -522,6 +601,7 @@ async def handler(request): user_cancelled_context = None server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) with trio.move_on_after(2): with pytest.raises(trio.Cancelled) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): @@ -538,15 +618,16 @@ async def handler(request): assert exc_info.value.__context__ is user_cancelled_context def _trio_default_non_strict_exception_groups() -> bool: - assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" - return int(trio.__version__[2:4]) < 25 + version = trio.__version__ # type: ignore[attr-defined] + assert re.match(r'^0\.\d\d\.', version), "unexpected trio versioning scheme" + return int(version[2:4]) < 25 @fail_after(1) async def test_handshake_exception_before_accept() -> None: ''' In #107, a request handler that throws an exception before finishing the handshake causes the task to hang. The proper behavior is to raise an exception to the nursery as soon as possible. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: raise ValueError() # pylint fails to resolve that BaseExceptionGroup will always be available @@ -554,6 +635,7 @@ async def handler(request): async with trio.open_nursery() as nursery: server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): pass @@ -572,10 +654,11 @@ async def handler(request): RaisesGroup(ValueError)))).matches(exc.value) -async def test_user_exception_cause(nursery) -> None: - async def handler(request): +async def test_user_exception_cause(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: await request.accept() server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) e_context = TypeError("foo") e_primary = ValueError("bar") e_cause = RuntimeError("zee") @@ -591,12 +674,13 @@ async def handler(request): assert e.__context__ is e_context @fail_after(1) -async def test_reject_handshake(nursery): - async def handler(request): +async def test_reject_handshake(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: body = b'My body' await request.reject(400, body=body) server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) with pytest.raises(ConnectionRejected) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): pass @@ -605,17 +689,18 @@ async def handler(request): @fail_after(1) -async def test_reject_handshake_invalid_info_status(nursery): +async def test_reject_handshake_invalid_info_status(nursery: trio.Nursery) -> None: ''' An informational status code that is not 101 should cause the client to reject the handshake. Since it is an informational response, there will not be a response body, so this test exercises a different code path. ''' - async def handler(stream): + async def handler(stream: trio.SocketStream) -> None: await stream.send_all(b'HTTP/1.1 100 CONTINUE\r\n\r\n') await stream.receive_some(max_bytes=1024) serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) - listeners = await nursery.start(serve_fn) + raw_listeners = await nursery.start(serve_fn) + listeners = cast(list[trio.SocketListener], raw_listeners) port = listeners[0].socket.getsockname()[1] with pytest.raises(ConnectionRejected) as exc_info: @@ -627,7 +712,7 @@ async def handler(stream): assert exc.body is None -async def test_handshake_protocol_error(echo_server): +async def test_handshake_protocol_error(echo_server: WebSocketServer) -> None: ''' If a client connects to a trio-websocket server and tries to speak HTTP instead of WebSocket, the server should reject the connection. (If the @@ -641,29 +726,29 @@ async def test_handshake_protocol_error(echo_server): assert response.startswith(b'HTTP/1.1 400') -async def test_client_send_and_receive(echo_conn): +async def test_client_send_and_receive(echo_conn: WebSocketConnection) -> None: async with echo_conn: await echo_conn.send_message('This is a test message.') received_msg = await echo_conn.get_message() assert received_msg == 'This is a test message.' -async def test_client_send_invalid_type(echo_conn): +async def test_client_send_invalid_type(echo_conn: WebSocketConnection) -> None: async with echo_conn: with pytest.raises(ValueError): - await echo_conn.send_message(object()) + await echo_conn.send_message(object()) # type: ignore[arg-type] -async def test_client_ping(echo_conn): +async def test_client_ping(echo_conn: WebSocketConnection) -> None: async with echo_conn: await echo_conn.ping(b'A') with pytest.raises(ConnectionClosed): await echo_conn.ping(b'B') -async def test_client_ping_two_payloads(echo_conn): +async def test_client_ping_two_payloads(echo_conn: WebSocketConnection) -> None: pong_count = 0 - async def ping_and_count(): + async def ping_and_count() -> None: nonlocal pong_count await echo_conn.ping() pong_count += 1 @@ -674,12 +759,12 @@ async def ping_and_count(): assert pong_count == 2 -async def test_client_ping_same_payload(echo_conn): +async def test_client_ping_same_payload(echo_conn: WebSocketConnection) -> None: # This test verifies that two tasks can't ping with the same payload at the # same time. One of them should succeed and the other should get an # exception. exc_count = 0 - async def ping_and_catch(): + async def ping_and_catch() -> None: nonlocal exc_count try: await echo_conn.ping(b'A') @@ -692,47 +777,53 @@ async def ping_and_catch(): assert exc_count == 1 -async def test_client_pong(echo_conn): +async def test_client_pong(echo_conn: WebSocketConnection) -> None: async with echo_conn: await echo_conn.pong(b'A') with pytest.raises(ConnectionClosed): await echo_conn.pong(b'B') -async def test_client_default_close(echo_conn): +async def test_client_default_close(echo_conn: WebSocketConnection) -> None: async with echo_conn: assert not echo_conn.closed + assert isinstance(echo_conn.closed, CloseReason) assert echo_conn.closed.code == 1000 assert echo_conn.closed.reason is None assert repr(echo_conn.closed) == 'CloseReason' -async def test_client_nondefault_close(echo_conn): +async def test_client_nondefault_close(echo_conn: WebSocketConnection) -> None: async with echo_conn: assert not echo_conn.closed await echo_conn.aclose(code=1001, reason='test reason') + assert isinstance(echo_conn.closed, CloseReason) assert echo_conn.closed.code == 1001 assert echo_conn.closed.reason == 'test reason' -async def test_wrap_client_stream(nursery): +async def test_wrap_client_stream(nursery: trio.Nursery) -> None: listener = MemoryListener() - server = WebSocketServer(echo_request_handler, [listener]) + server = WebSocketServer(echo_request_handler, [listener]) # type: ignore[list-item] await nursery.start(server.run) stream = await listener.connect() - conn = await wrap_client_stream(nursery, stream, HOST, RESOURCE) + conn = await wrap_client_stream( + nursery, + stream, # type: ignore[arg-type] + HOST, RESOURCE) async with conn: assert not conn.closed await conn.send_message('Hello from client!') msg = await conn.get_message() assert msg == 'Hello from client!' + assert isinstance(conn.local, str) assert conn.local.startswith('StapledStream(') assert conn.closed -async def test_wrap_server_stream(nursery): - async def handler(stream): +async def test_wrap_server_stream(nursery: trio.Nursery) -> None: + async def handler(stream: trio.SocketStream) -> None: request = await wrap_server_stream(nursery, stream) server_ws = await request.accept() async with server_ws: @@ -741,25 +832,30 @@ async def handler(stream): assert msg == 'Hello from client!' assert server_ws.closed serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) - listeners = await nursery.start(serve_fn) + raw_listeners = await nursery.start(serve_fn) + listeners = cast(list[trio.SocketListener], raw_listeners) port = listeners[0].socket.getsockname()[1] async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: await client.send_message('Hello from client!') @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_client_open_timeout(nursery, autojump_clock): +async def test_client_open_timeout( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: ''' The client times out waiting for the server to complete the opening handshake. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: await trio.sleep(FORCE_TIMEOUT) await request.accept() pytest.fail('Should not reach this line.') server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) with pytest.raises(ConnectionTimeout): async with open_websocket(HOST, server.port, '/', use_ssl=False, @@ -768,7 +864,10 @@ async def handler(request): @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_client_close_timeout(nursery, autojump_clock): +async def test_client_close_timeout( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: ''' This client times out waiting for the server to complete the closing handshake. @@ -778,7 +877,7 @@ async def test_client_close_timeout(nursery, autojump_clock): server's reader so it won't do the closing handshake for at least ``FORCE_TIMEOUT`` seconds. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await trio.sleep(FORCE_TIMEOUT) # The next line should raise ConnectionClosed. @@ -788,6 +887,7 @@ async def handler(request): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None, message_queue_size=0)) + assert isinstance(server, WebSocketServer) with pytest.raises(DisconnectionTimeout): async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, @@ -795,7 +895,7 @@ async def handler(request): await client_ws.send_message('test') -async def test_client_connect_networking_error(): +async def test_client_connect_networking_error() -> None: with patch('trio_websocket._impl.connect_websocket') as \ connect_websocket_mock: connect_websocket_mock.side_effect = OSError() @@ -805,7 +905,7 @@ async def test_client_connect_networking_error(): @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_server_open_timeout(autojump_clock): +async def test_server_open_timeout(autojump_clock: trio.testing.MockClock) -> None: ''' The server times out waiting for the client to complete the opening handshake. @@ -814,12 +914,13 @@ async def test_server_open_timeout(autojump_clock): in an internal nursery and sending exceptions wouldn't be helpful. Instead, timed out tasks silently end. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: pytest.fail('This handler should not be called.') async with trio.open_nursery() as nursery: server = await nursery.start(partial(serve_websocket, handler, HOST, 0, ssl_context=None, handler_nursery=nursery, connect_timeout=TIMEOUT)) + assert isinstance(server, WebSocketServer) old_task_count = len(nursery.child_tasks) # This stream is not a WebSocket, so it won't send a handshake: @@ -837,7 +938,7 @@ async def handler(request): @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_server_close_timeout(autojump_clock): +async def test_server_close_timeout(autojump_clock: trio.testing.MockClock) -> None: ''' The server times out waiting for the client to complete the closing handshake. @@ -850,7 +951,7 @@ async def test_server_close_timeout(autojump_clock): its message queue size is 0 and the server sends it exactly 1 message. This blocks the client's reader and prevents it from doing the client handshake. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: ws = await request.accept() # Send one message to block the client's reader task: await ws.send_message('test') @@ -859,6 +960,7 @@ async def handler(request): server = await outer.start(partial(serve_websocket, handler, HOST, 0, ssl_context=None, handler_nursery=outer, disconnect_timeout=TIMEOUT)) + assert isinstance(server, WebSocketServer) old_task_count = len(outer.child_tasks) # Spawn client inside an inner nursery so that we can cancel it's reader @@ -883,12 +985,13 @@ async def handler(request): outer.cancel_scope.cancel() -async def test_client_does_not_close_handshake(nursery): - async def handler(request): +async def test_client_does_not_close_handshake(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() with pytest.raises(ConnectionClosed): await server_ws.get_message() server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) stream = await trio.open_tcp_stream(HOST, server.port) client_ws = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with client_ws: @@ -897,10 +1000,10 @@ async def handler(request): await client_ws.send_message('Hello from client!') -async def test_server_sends_after_close(nursery): +async def test_server_sends_after_close(nursery: trio.Nursery) -> None: done = trio.Event() - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() with pytest.raises(ConnectionClosed): while True: @@ -908,6 +1011,7 @@ async def handler(request): done.set() server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) stream = await trio.open_tcp_stream(HOST, server.port) client_ws = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with client_ws: @@ -918,8 +1022,8 @@ async def handler(request): await done.wait() -async def test_server_does_not_close_handshake(nursery): - async def handler(stream): +async def test_server_does_not_close_handshake(nursery: trio.Nursery) -> None: + async def handler(stream: trio.SocketStream) -> None: request = await wrap_server_stream(nursery, stream) server_ws = await request.accept() async with server_ws: @@ -927,20 +1031,25 @@ async def handler(stream): with pytest.raises(ConnectionClosed): await server_ws.send_message('Hello from client!') serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) - listeners = await nursery.start(serve_fn) + raw_listeners = await nursery.start(serve_fn) + listeners = cast(list[trio.SocketListener], raw_listeners) port = listeners[0].socket.getsockname()[1] async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: with pytest.raises(ConnectionClosed): await client.get_message() -async def test_server_handler_exit(nursery, autojump_clock): - async def handler(request): +async def test_server_handler_exit( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: + async def handler(request: WebSocketRequest) -> None: await request.accept() await trio.sleep(1) server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) # connection should close when server handler exits with trio.fail_after(2): @@ -949,11 +1058,12 @@ async def handler(request): with pytest.raises(ConnectionClosed) as exc_info: await connection.get_message() exc = exc_info.value + assert isinstance(exc.reason, CloseReason) assert exc.reason.name == 'NORMAL_CLOSURE' @fail_after(DEFAULT_TEST_MAX_DURATION) -async def test_read_messages_after_remote_close(nursery): +async def test_read_messages_after_remote_close(nursery: trio.Nursery) -> None: ''' When the remote endpoint closes, the local endpoint can still read all of the messages sent prior to closing. Any attempt to read beyond that will @@ -963,7 +1073,7 @@ async def test_read_messages_after_remote_close(nursery): ''' server_closed = trio.Event() - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server = await request.accept() async with server: await server.send_message('1') @@ -972,6 +1082,7 @@ async def handler(request): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) # The client needs a message queue of size 2 so that it can buffer both # incoming messages without blocking the reader task. @@ -984,14 +1095,14 @@ async def handler(request): await client.get_message() -async def test_no_messages_after_local_close(nursery): +async def test_no_messages_after_local_close(nursery: trio.Nursery) -> None: ''' If the local endpoint initiates closing, then pending messages are discarded and any attempt to read a message will raise ConnectionClosed. ''' client_closed = trio.Event() - async def handler(request): + async def handler(request: WebSocketRequest) -> None: # The server sends some messages and then closes. server = await request.accept() async with server: @@ -1001,6 +1112,7 @@ async def handler(request): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, '/', use_ssl=False) as client: pass @@ -1009,7 +1121,10 @@ async def handler(request): client_closed.set() -async def test_cm_exit_with_pending_messages(echo_server, autojump_clock): +async def test_cm_exit_with_pending_messages( + echo_server: WebSocketServer, + autojump_clock: trio.testing.MockClock, +) -> None: ''' Regression test for #74, where a context manager was not able to exit when there were pending messages in the receive queue. @@ -1023,13 +1138,13 @@ async def test_cm_exit_with_pending_messages(echo_server, autojump_clock): @fail_after(DEFAULT_TEST_MAX_DURATION) -async def test_max_message_size(nursery): +async def test_max_message_size(nursery: trio.Nursery) -> None: ''' Set the client's max message size to 100 bytes. The client can send a message larger than 100 bytes, but when it receives a message larger than 100 bytes, it closes the connection with code 1009. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: ''' Similar to the echo_request_handler fixture except it runs in a loop. ''' conn = await request.accept() @@ -1042,6 +1157,7 @@ async def handler(request): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, max_message_size=100) as client: @@ -1057,10 +1173,13 @@ async def handler(request): assert client.closed.code == 1009 -async def test_server_close_client_disconnect_race(nursery, autojump_clock): +async def test_server_close_client_disconnect_race( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: """server attempts close just as client disconnects (issue #96)""" - async def handler(request: WebSocketRequest): + async def handler(request: WebSocketRequest) -> None: ws = await request.accept() ws._for_testing_peer_closed_connection = trio.Event() await ws.send_message('foo') @@ -1070,6 +1189,7 @@ async def handler(request: WebSocketRequest): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) connection = await connect_websocket(nursery, HOST, server.port, RESOURCE, use_ssl=False) @@ -1078,7 +1198,10 @@ async def handler(request: WebSocketRequest): await trio.sleep(.1) -async def test_remote_close_local_message_race(nursery, autojump_clock): +async def test_remote_close_local_message_race( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: """as remote initiates close, local attempts message (issue #175) This exposed multiple problems in the trio-websocket API and implementation: @@ -1089,13 +1212,14 @@ async def test_remote_close_local_message_race(nursery, autojump_clock): * with wsproto >= 1.2.0, LocalProtocolError will be leaked """ - async def handler(request: WebSocketRequest): + async def handler(request: WebSocketRequest) -> None: ws = await request.accept() await ws.get_message() await ws.aclose() server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) client = await connect_websocket(nursery, HOST, server.port, RESOURCE, use_ssl=False) @@ -1106,27 +1230,28 @@ async def handler(request: WebSocketRequest): await client.send_message('bar') -async def test_message_after_local_close_race(nursery): +async def test_message_after_local_close_race(nursery: trio.Nursery) -> None: """test message send during local-initiated close handshake (issue #158)""" - async def handler(request: WebSocketRequest): + async def handler(request: WebSocketRequest) -> None: await request.accept() await trio.sleep_forever() server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) client = await connect_websocket(nursery, HOST, server.port, RESOURCE, use_ssl=False) orig_send = client._send close_sent = trio.Event() - async def _send_wrapper(event): + async def _send_wrapper(event: Event) -> None: if isinstance(event, CloseConnection): close_sent.set() return await orig_send(event) - client._send = _send_wrapper + client._send = _send_wrapper # type: ignore[method-assign] assert not client.closed nursery.start_soon(client.aclose) await close_sent.wait() @@ -1136,21 +1261,22 @@ async def _send_wrapper(event): @fail_after(DEFAULT_TEST_MAX_DURATION) -async def test_server_tcp_closed_on_close_connection_event(nursery): +async def test_server_tcp_closed_on_close_connection_event(nursery: trio.Nursery) -> None: """ensure server closes TCP immediately after receiving CloseConnection""" server_stream_closed = trio.Event() - async def _close_stream_stub(): + async def _close_stream_stub() -> None: assert not server_stream_closed.is_set() server_stream_closed.set() - async def handle_connection(request): + async def handle_connection(request: WebSocketRequest) -> None: ws = await request.accept() - ws._close_stream = _close_stream_stub + ws._close_stream = _close_stream_stub # type: ignore[method-assign] await trio.sleep_forever() server = await nursery.start( partial(serve_websocket, handle_connection, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) client = await connect_websocket(nursery, HOST, server.port, RESOURCE, use_ssl=False) # send a CloseConnection event to server but leave client connected @@ -1158,7 +1284,10 @@ async def handle_connection(request): await server_stream_closed.wait() -async def test_finalization_dropped_exception(echo_server, autojump_clock): +async def test_finalization_dropped_exception( + echo_server: WebSocketServer, + autojump_clock: trio.testing.MockClock, +) -> None: # Confirm that open_websocket finalization does not contribute to dropped # exceptions as described in https://github.com/python-trio/trio/issues/1559. with pytest.raises(ValueError): @@ -1170,7 +1299,7 @@ async def test_finalization_dropped_exception(echo_server, autojump_clock): raise ValueError -async def test_remote_close_rude(): +async def test_remote_close_rude() -> None: """ Bad ordering: 1. Remote close @@ -1180,14 +1309,17 @@ async def test_remote_close_rude(): """ client_stream, server_stream = memory_stream_pair() - async def client(): - client_conn = await wrap_client_stream(nursery, client_stream, HOST, RESOURCE) + async def client() -> None: + client_conn = await wrap_client_stream( + nursery, + client_stream, # type: ignore[arg-type] + HOST, RESOURCE) assert not client_conn.closed await client_conn.send_message('Hello from client!') with pytest.raises(ConnectionClosed): await client_conn.get_message() - async def server(): + async def server() -> None: server_request = await wrap_server_stream(nursery, server_stream) server_ws = await server_request.accept() assert not server_ws.closed @@ -1208,14 +1340,16 @@ async def server(): nursery.start_soon(client) -def test_copy_exceptions(): +def test_copy_exceptions() -> None: # test that exceptions are copy- and pickleable copy.copy(HandshakeError()) copy.copy(ConnectionTimeout()) copy.copy(DisconnectionTimeout()) - assert copy.copy(ConnectionClosed("foo")).reason == "foo" + assert copy.copy( + ConnectionClosed("foo") + ).reason == "foo" # type: ignore[comparison-overlap,arg-type] - rej_copy = copy.copy(ConnectionRejected(404, (("a", "b"),), b"c")) + rej_copy = copy.copy(ConnectionRejected(404, ((b"a", b"b"),), b"c")) assert rej_copy.status_code == 404 - assert rej_copy.headers == (("a", "b"),) + assert rej_copy.headers == ((b"a", b"b"),) assert rej_copy.body == b"c" diff --git a/trio_websocket/__init__.py b/trio_websocket/__init__.py index 82ca0ae..afa944c 100644 --- a/trio_websocket/__init__.py +++ b/trio_websocket/__init__.py @@ -1,20 +1,21 @@ +# pylint: disable=useless-import-alias from ._impl import ( - CloseReason, - ConnectionClosed, - ConnectionRejected, - ConnectionTimeout, - connect_websocket, - connect_websocket_url, - DisconnectionTimeout, - Endpoint, - HandshakeError, - open_websocket, - open_websocket_url, - WebSocketConnection, - WebSocketRequest, - WebSocketServer, - wrap_client_stream, - wrap_server_stream, - serve_websocket, + CloseReason as CloseReason, + ConnectionClosed as ConnectionClosed, + ConnectionRejected as ConnectionRejected, + ConnectionTimeout as ConnectionTimeout, + connect_websocket as connect_websocket, + connect_websocket_url as connect_websocket_url, + DisconnectionTimeout as DisconnectionTimeout, + Endpoint as Endpoint, + HandshakeError as HandshakeError, + open_websocket as open_websocket, + open_websocket_url as open_websocket_url, + WebSocketConnection as WebSocketConnection, + WebSocketRequest as WebSocketRequest, + WebSocketServer as WebSocketServer, + wrap_client_stream as wrap_client_stream, + wrap_server_stream as wrap_server_stream, + serve_websocket as serve_websocket, ) -from ._version import __version__ +from ._version import __version__ as __version__ diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 6cba00a..0b21cc7 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1427,9 +1427,9 @@ async def _handle_message_event( msg: str | bytes # Type checker does not understand `_message_parts` if isinstance(event, BytesMessage): - msg = b''.join(self._message_parts) + msg = b''.join(cast(list[bytes], self._message_parts)) else: - msg = ''.join(self._message_parts) + msg = ''.join(cast(list[str], self._message_parts)) self._message_size = 0 self._message_parts = [] try: From a2d83f5ddea22e49706ee55aaa8ba3dbfc3d83b2 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Mon, 13 Jan 2025 01:20:39 -0600 Subject: [PATCH 24/37] Use strings for cast values --- tests/test_connection.py | 12 ++++++------ trio_websocket/_impl.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index db1bf69..c1c7948 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -161,7 +161,7 @@ async def wrapper(*args: PS.args, **kwargs: PS.kwargs) -> T | None: @attr.s(hash=False, eq=False) class MemoryListener( trio.abc.Listener[ - trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] + "trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]" ] ): closed: bool = attr.ib(default=False) @@ -700,7 +700,7 @@ async def handler(stream: trio.SocketStream) -> None: await stream.receive_some(max_bytes=1024) serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) raw_listeners = await nursery.start(serve_fn) - listeners = cast(list[trio.SocketListener], raw_listeners) + listeners = cast("list[trio.SocketListener]", raw_listeners) port = listeners[0].socket.getsockname()[1] with pytest.raises(ConnectionRejected) as exc_info: @@ -833,7 +833,7 @@ async def handler(stream: trio.SocketStream) -> None: assert server_ws.closed serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) raw_listeners = await nursery.start(serve_fn) - listeners = cast(list[trio.SocketListener], raw_listeners) + listeners = cast("list[trio.SocketListener]", raw_listeners) port = listeners[0].socket.getsockname()[1] async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: await client.send_message('Hello from client!') @@ -1032,7 +1032,7 @@ async def handler(stream: trio.SocketStream) -> None: await server_ws.send_message('Hello from client!') serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) raw_listeners = await nursery.start(serve_fn) - listeners = cast(list[trio.SocketListener], raw_listeners) + listeners = cast("list[trio.SocketListener]", raw_listeners) port = listeners[0].socket.getsockname()[1] async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: with pytest.raises(ConnectionClosed): @@ -1346,8 +1346,8 @@ def test_copy_exceptions() -> None: copy.copy(ConnectionTimeout()) copy.copy(DisconnectionTimeout()) assert copy.copy( - ConnectionClosed("foo") - ).reason == "foo" # type: ignore[comparison-overlap,arg-type] + ConnectionClosed("foo") # type: ignore[arg-type] + ).reason == "foo" # type: ignore[comparison-overlap] rej_copy = copy.copy(ConnectionRejected(404, ((b"a", b"b"),), b"c")) assert rej_copy.status_code == 404 diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 0b21cc7..202cb8a 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1427,9 +1427,9 @@ async def _handle_message_event( msg: str | bytes # Type checker does not understand `_message_parts` if isinstance(event, BytesMessage): - msg = b''.join(cast(list[bytes], self._message_parts)) + msg = b''.join(cast("list[bytes]", self._message_parts)) else: - msg = ''.join(cast(list[str], self._message_parts)) + msg = ''.join(cast("list[str]", self._message_parts)) self._message_size = 0 self._message_parts = [] try: From 335dab26c20f59a7ee6c7f550db962af20ecb050 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Mon, 13 Jan 2025 01:24:59 -0600 Subject: [PATCH 25/37] Missed another type quote --- tests/test_connection.py | 2 +- trio_websocket/_impl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index c1c7948..b813b9e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -176,7 +176,7 @@ class MemoryListener( trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] ], ] = attr.ib(factory=lambda: trio.open_memory_channel[ - trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] + "trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]" ](1)) accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 202cb8a..a38a31b 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1693,7 +1693,7 @@ def listeners(self) -> list[Endpoint | str]: listeners.append(repr(listener)) return listeners - async def run( + async def run( # type: ignore[misc] self, *, task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, From 743e4cb4e40113f7aafe61776c0acb35583f17c4 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Mon, 13 Jan 2025 14:34:54 -0600 Subject: [PATCH 26/37] Fix some weird multiline string things --- trio_websocket/_impl.py | 48 ++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 3e31a41..c33b860 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -36,12 +36,14 @@ # pylint doesn't care about the version_info check, so need to ignore the warning from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin -_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22) +_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split(".")[:2])) < (0, 22) if _IS_TRIO_MULTI_ERROR: _TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member else: - _TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment + _TRIO_EXC_GROUP_TYPE = ( + BaseExceptionGroup # pylint: disable=possibly-used-before-assignment + ) CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds MESSAGE_QUEUE_SIZE = 1 @@ -85,9 +87,15 @@ def __exit__(self, ty, value, tb): return False if _IS_TRIO_MULTI_ERROR: # pragma: no cover - filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member - elif isinstance(value, BaseExceptionGroup): # pylint: disable=possibly-used-before-assignment - filtered_exception = value.subgroup(lambda exc: not isinstance(exc, trio.Cancelled)) + filtered_exception = trio.MultiError.filter( + _ignore_cancel, value + ) # pylint: disable=no-member + elif isinstance( + value, BaseExceptionGroup + ): # pylint: disable=possibly-used-before-assignment + filtered_exception = value.subgroup( + lambda exc: not isinstance(exc, trio.Cancelled) + ) else: filtered_exception = _ignore_cancel(value) return filtered_exception is None @@ -138,7 +146,7 @@ async def open_websocket( :raises HandshakeError: for any networking error, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. - ''' + """ # This context manager tries very very hard not to raise an exceptiongroup # in order to be as transparent as possible for the end user. @@ -161,12 +169,16 @@ async def open_websocket( # exception in the last `finally`. If we encountered exceptions in user code # or in reader task then they will be set as the `__context__`. - async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection: try: with trio.fail_after(connect_timeout): - return await connect_websocket(nursery, host, port, - resource, use_ssl=use_ssl, subprotocols=subprotocols, + return await connect_websocket( + nursery, + host, + port, + resource, + use_ssl=use_ssl, + subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size, @@ -194,7 +206,7 @@ def _raise(exc: BaseException) -> NoReturn: exc.__context__ = context del exc, context - connection: WebSocketConnection|None=None + connection: WebSocketConnection | None = None close_result: outcome.Maybe[None] | None = None user_error = None @@ -227,7 +239,7 @@ def _raise(exc: BaseException) -> NoReturn: _raise(e.exceptions[0]) # contains at most 1 non-cancelled exceptions - exception_to_raise: BaseException|None = None + exception_to_raise: BaseException | None = None for sub_exc in e.exceptions: if not isinstance(sub_exc, trio.Cancelled): if exception_to_raise is not None: @@ -257,13 +269,16 @@ def _raise(exc: BaseException) -> NoReturn: # and, if not None, `user_error.__context__` if user_error is not None: exceptions = [subexc for subexc in e.exceptions if subexc is not user_error] - eg_substr = '' + eg_substr = "" # there's technically loss of info here, with __suppress_context__=True you # still have original __context__ available, just not printed. But we delete # it completely because we can't partially suppress the group - if user_error.__context__ is not None and not user_error.__suppress_context__: + if ( + user_error.__context__ is not None + and not user_error.__suppress_context__ + ): exceptions.append(user_error.__context__) - eg_substr = ' and the context for the user exception' + eg_substr = " and the context for the user exception" eg_str = ( "Both internal and user exceptions encountered. This group contains " "the internal exception(s)" + eg_substr + "." @@ -282,7 +297,6 @@ def _raise(exc: BaseException) -> NoReturn: if close_result is not None: close_result.unwrap() - # error setting up, unwrap that exception if connection is None: result.unwrap() @@ -695,7 +709,7 @@ def __init__(self, reason): :param reason: :type reason: CloseReason - ''' + """ super().__init__(reason) self.reason = reason @@ -716,7 +730,7 @@ def __init__(self, status_code, headers, body): :param reason: :type reason: CloseReason - ''' + """ super().__init__(status_code, headers, body) #: a 3 digit HTTP status code self.status_code = status_code From 8bd6a1c2dcbb767ee81acef9813e98a98fb8ccba Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Mon, 13 Jan 2025 14:37:01 -0600 Subject: [PATCH 27/37] Fix pylint ignore comment location --- trio_websocket/_impl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index c33b860..26524e8 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -87,9 +87,9 @@ def __exit__(self, ty, value, tb): return False if _IS_TRIO_MULTI_ERROR: # pragma: no cover - filtered_exception = trio.MultiError.filter( + filtered_exception = trio.MultiError.filter( # pylint: disable=no-member _ignore_cancel, value - ) # pylint: disable=no-member + ) elif isinstance( value, BaseExceptionGroup ): # pylint: disable=possibly-used-before-assignment From 97fae8eb24eb51bf1db5f23dfa49874226f5db79 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Mon, 13 Jan 2025 14:51:32 -0600 Subject: [PATCH 28/37] Fix another pylint comment --- tests/test_connection.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 7cce20b..75bdda4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -528,8 +528,9 @@ async def handler(request): assert exc_info.value is user_error e_context = exc_info.value.__context__ assert isinstance( - e_context, BaseExceptionGroup - ) # pylint: disable=possibly-used-before-assignment + e_context, + BaseExceptionGroup, # pylint: disable=possibly-used-before-assignment + ) assert internal_error in e_context.exceptions assert user_error_context in e_context.exceptions From fd09470688130703ee70b54757e881b3be2d9d22 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Mon, 13 Jan 2025 23:00:44 -0600 Subject: [PATCH 29/37] Remove redundant command line flags --- Makefile | 2 +- trio_websocket/_impl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index f97c321..346a4c9 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ lint: $(PYTHON) -m pylint trio_websocket/ tests/ autobahn/ examples/ typecheck: - $(PYTHON) -m mypy --explicit-package-bases trio_websocket tests autobahn examples + $(PYTHON) -m mypy publish: rm -fr build dist .egg trio_websocket.egg-info diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index a38a31b..202cb8a 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1693,7 +1693,7 @@ def listeners(self) -> list[Endpoint | str]: listeners.append(repr(listener)) return listeners - async def run( # type: ignore[misc] + async def run( self, *, task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, From b5233c81fea57857da4c8a911e40ef82188d4140 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Mon, 13 Jan 2025 23:02:23 -0600 Subject: [PATCH 30/37] Remove now un-needed annotation for socket listeners --- trio_websocket/_impl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 202cb8a..40c3638 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -651,9 +651,7 @@ async def serve_websocket( host=host, https_compatible=True, ) - listeners: list[trio.SSLListener[trio.SocketStream]] | list[trio.SocketListener] = ( - await open_tcp_listeners() - ) + listeners = await open_tcp_listeners() server = WebSocketServer( handler, listeners, From ce73cb1740829e04e6e2f9c9e82092ded0f4bd12 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Mon, 13 Jan 2025 23:10:32 -0600 Subject: [PATCH 31/37] Re-add type ignore because of old trio version support --- trio_websocket/_impl.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 40c3638..c8c93fd 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1691,7 +1691,11 @@ def listeners(self) -> list[Endpoint | str]: listeners.append(repr(listener)) return listeners - async def run( + # Type ignore is because type checker does not think NoReturn is + # real for Trio 0.25.1 (current version used in requirements file as + # of writing). Not a problem for newer versions however, which is + # why we have unused-ignore as well. + async def run( # type: ignore[misc,unused-ignore] self, *, task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, From e98f2333d8e16df2381767937f9d2da1280fe6a3 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:09:21 -0600 Subject: [PATCH 32/37] Suggestions from code review Co-authored-by: jakkdl --- autobahn/client.py | 7 ++++-- mypy.ini | 28 ------------------------ pyproject.toml | 13 +++++++++++ setup.py | 1 + tests/test_connection.py | 47 ++++++++++++++++++++-------------------- trio_websocket/_impl.py | 29 ++++++++++++++++--------- 6 files changed, 61 insertions(+), 64 deletions(-) delete mode 100644 mypy.ini create mode 100644 pyproject.toml diff --git a/autobahn/client.py b/autobahn/client.py index 1d8a082..1cfb9d2 100644 --- a/autobahn/client.py +++ b/autobahn/client.py @@ -26,7 +26,7 @@ async def get_case_count(url: str) -> int: return int(case_count) -async def get_case_info(url: str, case: str) -> Any: +async def get_case_info(url: str, case: str) -> object: url = f'{url}/getCaseInfo?case={case}' async with open_websocket_url(url) as conn: return json.loads(await conn.get_message()) @@ -63,7 +63,10 @@ async def run_tests(args: argparse.Namespace) -> None: test_cases = list(range(1, case_count + 1)) exception_cases = [] for case in test_cases: - case_id = (await get_case_info(args.url, case))['id'] + result = await get_case_info(args.url, case) + assert isinstance(result, dict) + case_id = result['id'] + assert isinstance(case_id, int) if case_count: logger.info("Running test case %s (%d of %d)", case_id, case, case_count) else: diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 09cd26d..0000000 --- a/mypy.ini +++ /dev/null @@ -1,28 +0,0 @@ -[mypy] -explicit_package_bases = true -files = trio_websocket,tests,autobahn,examples -show_column_numbers = true -show_error_codes = true -show_traceback = true -warn_redundant_casts = true -warn_unused_configs = true - -[mypy-trio_websocket] -check_untyped_defs = true -disallow_any_decorated = true -disallow_any_generics = true -disallow_any_unimported = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -ignore_missing_imports = true -local_partial_types = true -no_implicit_optional = true -no_implicit_reexport = true -strict = true -strict_equality = true -warn_return_any = true -warn_unreachable = true -warn_unused_ignores = true diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..95d5ff9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[tool.mypy] +explicit_package_bases = true +files = ["trio_websocket", "tests", "autobahn", "examples"] +show_column_numbers = true +show_error_codes = true +show_traceback = true +disallow_any_decorated = true +disallow_any_unimported = true +ignore_missing_imports = true +local_partial_types = true +no_implicit_optional = true +strict = true +warn_unreachable = true diff --git a/setup.py b/setup.py index b38bb70..46c6506 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy', + 'Typing :: Typed', ], python_requires=">=3.8", keywords='websocket client server trio', diff --git a/tests/test_connection.py b/tests/test_connection.py index b813b9e..74fafd1 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -39,6 +39,7 @@ from functools import partial, wraps from typing import TYPE_CHECKING, TypeVar, cast from unittest.mock import Mock, patch +from importlib.metadata import version import attr import pytest @@ -89,9 +90,11 @@ from collections.abc import Awaitable, Callable from wsproto.events import Event - from typing_extensions import ParamSpec + from typing_extensions import ParamSpec, TypeAlias PS = ParamSpec("PS") + StapledMemoryStream: TypeAlias = trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] + WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.'))) HOST = '127.0.0.1' @@ -116,6 +119,8 @@ async def echo_server(nursery: trio.Nursery) -> AsyncGenerator[WebSocketServer, serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0, ssl_context=None) server = await nursery.start(serve_fn) + # Cast needed because currently `nursery.start` has typing issues + # blocked by https://github.com/python/mypy/pull/17512 yield cast(WebSocketServer, server) @@ -147,37 +152,28 @@ def __init__(self, seconds: int) -> None: self._seconds = seconds def __call__(self, fn: Callable[PS, Awaitable[T]]) -> Callable[PS, Awaitable[T | None]]: + # Type of decorated function contains type `Any` @wraps(fn) - async def wrapper(*args: PS.args, **kwargs: PS.kwargs) -> T | None: - result: T | None = None + async def wrapper( # type: ignore[misc] + *args: PS.args, + **kwargs: PS.kwargs, + ) -> T: with trio.move_on_after(self._seconds) as cancel_scope: - result = await fn(*args, **kwargs) + return await fn(*args, **kwargs) if cancel_scope.cancelled_caught: pytest.fail(f'Test runtime exceeded the maximum {self._seconds} seconds') - return result + raise AssertionError("Should be unreachable") return wrapper @attr.s(hash=False, eq=False) -class MemoryListener( - trio.abc.Listener[ - "trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]" - ] -): +class MemoryListener(trio.abc.Listener["StapledMemoryStream"]): closed: bool = attr.ib(default=False) - accepted_streams: list[ - trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] - ] = attr.ib(factory=list) + accepted_streams: list[StapledMemoryStream] = attr.ib(factory=list) queued_streams: tuple[ - trio.MemorySendChannel[ - trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] - ], - trio.MemoryReceiveChannel[ - trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] - ], - ] = attr.ib(factory=lambda: trio.open_memory_channel[ - "trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]" - ](1)) + trio.MemorySendChannel[StapledMemoryStream], + trio.MemoryReceiveChannel[StapledMemoryStream], + ] = attr.ib(factory=lambda: trio.open_memory_channel["StapledMemoryStream"](1)) accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None) async def connect(self) -> trio.StapledStream[ @@ -385,8 +381,11 @@ async def test_ascii_encoded_path_is_ok(echo_server: WebSocketServer) -> None: assert conn.path == RESOURCE + '/' + path +# Type ignore because @patch contains `Any` @patch('trio_websocket._impl.open_websocket') -def test_client_open_url_options(open_websocket_mock: Mock) -> None: +def test_client_open_url_options( # type: ignore[misc] + open_websocket_mock: Mock, +) -> None: """open_websocket_url() must pass its options on to open_websocket()""" port = 1234 url = f'ws://{HOST}:{port}{RESOURCE}' @@ -618,7 +617,7 @@ async def handler(request: WebSocketRequest) -> None: assert exc_info.value.__context__ is user_cancelled_context def _trio_default_non_strict_exception_groups() -> bool: - version = trio.__version__ # type: ignore[attr-defined] + version = version("trio") assert re.match(r'^0\.\d\d\.', version), "unexpected trio versioning scheme" return int(version[2:4]) < 25 diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index c8c93fd..b7c1ea4 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -12,6 +12,7 @@ import struct import urllib.parse from typing import Any, List, NoReturn, Optional, Union, TypeVar, TYPE_CHECKING, Generic, cast +from importlib.metadata import version import outcome import trio @@ -38,22 +39,21 @@ if TYPE_CHECKING: from types import TracebackType + from typing_extensions import Final from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Coroutine, Sequence -_IS_TRIO_MULTI_ERROR = tuple( - map(int, trio.__version__.split(".")[:2]) # type: ignore[attr-defined] -) < (0, 22) +_IS_TRIO_MULTI_ERROR: Final = tuple(map(int, version("trio").split(".")[:2])) < (0, 22) if _IS_TRIO_MULTI_ERROR: _TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member else: _TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment -CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds -MESSAGE_QUEUE_SIZE = 1 -MAX_MESSAGE_SIZE = 2 ** 20 # 1 MiB -RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB -logger = logging.getLogger('trio-websocket') +CONN_TIMEOUT: Final = 60 # default connect & disconnect timeout, in seconds +MESSAGE_QUEUE_SIZE: Final = 1 +MAX_MESSAGE_SIZE: Final = 2 ** 20 # 1 MiB +RECEIVE_BYTES: Final = 4 * 2 ** 10 # 4 KiB +logger: Final = logging.getLogger('trio-websocket') T = TypeVar("T") E = TypeVar("E", bound=BaseException) @@ -770,11 +770,16 @@ def __repr__(self) -> str: f'' +NULL: Final = object() + + class Future(Generic[T]): ''' Represents a value that will be available in the future. ''' def __init__(self) -> None: ''' Constructor. ''' - self._value: T | None = None + # We do some type shenanigins + # Would do `T | Literal[NULL]` but that's not right apparently. + self._value: T = cast(T, NULL) self._value_event = trio.Event() def set_value(self, value: T) -> None: @@ -793,7 +798,8 @@ async def wait_value(self) -> T: :returns: The value set by ``set_value()``. ''' await self._value_event.wait() - return cast(T, self._value) + assert self._value is not NULL + return self._value class WebSocketRequest: @@ -1509,6 +1515,9 @@ async def _reader_task(self) -> None: handler = handlers[event_type] logger.debug('%s received event: %s', self, event_type) + # Type checkers don't understand looking up type in handlers. + # If we wanted to fix, best I can figure is we'd need a huge + # if-else or case block for every type individually. await handler(event) # type: ignore[operator] except KeyError: logger.warning('%s received unknown event type: "%s"', self, From 96e437f853250a32eb8d3b31496e68a1546735e8 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:15:44 -0600 Subject: [PATCH 33/37] Fix lint issues --- autobahn/client.py | 1 - tests/test_connection.py | 11 +++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/autobahn/client.py b/autobahn/client.py index 1cfb9d2..5c94b56 100644 --- a/autobahn/client.py +++ b/autobahn/client.py @@ -6,7 +6,6 @@ import json import logging import sys -from typing import Any import trio from trio_websocket import open_websocket_url, ConnectionClosed diff --git a/tests/test_connection.py b/tests/test_connection.py index 74fafd1..dab2839 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -93,7 +93,10 @@ from typing_extensions import ParamSpec, TypeAlias PS = ParamSpec("PS") - StapledMemoryStream: TypeAlias = trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream] + StapledMemoryStream: TypeAlias = trio.StapledStream[ + trio.testing.MemorySendStream, + trio.testing.MemoryReceiveStream, + ] WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.'))) @@ -617,9 +620,9 @@ async def handler(request: WebSocketRequest) -> None: assert exc_info.value.__context__ is user_cancelled_context def _trio_default_non_strict_exception_groups() -> bool: - version = version("trio") - assert re.match(r'^0\.\d\d\.', version), "unexpected trio versioning scheme" - return int(version[2:4]) < 25 + trio_version = version("trio") + assert re.match(r'^0\.\d\d\.', trio_version), "unexpected trio versioning scheme" + return int(trio_version[2:4]) < 25 @fail_after(1) async def test_handshake_exception_before_accept() -> None: From 35a623513232242de1073cafd43aa5b9edd6a6f1 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 16 Jan 2025 13:06:07 -0600 Subject: [PATCH 34/37] Avoid a few type errors with fake socket listener Co-authored-by: A5rocks --- tests/test_connection.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index dab2839..45c1268 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -296,11 +296,15 @@ async def test_serve_with_zero_listeners() -> None: WebSocketServer(echo_request_handler, []) +def memory_listener() -> trio.SocketListener: + return MemoryListener() # type: ignore[return-value] + + async def test_serve_non_tcp_listener(nursery: trio.Nursery) -> None: - listeners = [MemoryListener()] + listeners = [memory_listener()] server = WebSocketServer( echo_request_handler, - listeners, # type: ignore[arg-type] + listeners, ) await nursery.start(server.run) assert len(server.listeners) == 1 @@ -313,11 +317,11 @@ async def test_serve_non_tcp_listener(nursery: trio.Nursery) -> None: async def test_serve_multiple_listeners(nursery: trio.Nursery) -> None: listener1 = (await trio.open_tcp_listeners(0, host=HOST))[0] - listener2 = MemoryListener() + listener2 = memory_listener() server = WebSocketServer( echo_request_handler, [ listener1, - listener2, # type: ignore[list-item] + listener2, ] ) await nursery.start(server.run) From 4ec83e3d2023feefccbbf8b5a8908d502acb93f2 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 30 Jan 2025 10:44:40 -0600 Subject: [PATCH 35/37] Fix merge issue --- setup.py | 36 +----------------------------------- 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/setup.py b/setup.py index 8ab161b..2166c5e 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,6 @@ author_email="mehaase@gmail.com", classifiers=[ # See https://pypi.org/classifiers/ -<<<<<<< HEAD "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Topic :: Software Development :: Libraries", @@ -36,45 +35,12 @@ "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", -||||||| e7706f4 - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Libraries', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', -======= - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Libraries', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', - 'Typing :: Typed', ->>>>>>> root/master + "Typing :: Typed", ], python_requires=">=3.8", -<<<<<<< HEAD keywords="websocket client server trio", packages=find_packages(exclude=["docs", "examples", "tests"]), -||||||| e7706f4 - keywords='websocket client server trio', - packages=find_packages(exclude=['docs', 'examples', 'tests']), -======= - keywords='websocket client server trio', - packages=find_packages(exclude=['docs', 'examples', 'tests']), package_data={"trio-websocket": ["py.typed"]}, ->>>>>>> root/master install_requires=[ 'exceptiongroup; python_version<"3.11"', "trio>=0.11", From d9081fd25e8ada0dfad40bb07df9b8fcaf99ec98 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 30 Jan 2025 10:49:11 -0600 Subject: [PATCH 36/37] Remove un-needed string concat --- autobahn/client.py | 6 ++---- autobahn/server.py | 4 +--- examples/server.py | 5 ++--- tests/test_connection.py | 18 ++++++++++++++---- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/autobahn/client.py b/autobahn/client.py index b0006e9..027bb4e 100644 --- a/autobahn/client.py +++ b/autobahn/client.py @@ -94,16 +94,14 @@ async def run_tests(args: argparse.Namespace) -> None: def parse_args() -> argparse.Namespace: """Parse command line arguments.""" - parser = argparse.ArgumentParser( - description="Autobahn client for" " trio-websocket" - ) + parser = argparse.ArgumentParser(description="Autobahn client for trio-websocket") parser.add_argument("url", help="WebSocket URL for server") # TODO: accept case ID's rather than indices parser.add_argument( "debug_cases", type=int, nargs="*", - help="Run" " individual test cases with debug logging (optional)", + help="Run individual test cases with debug logging (optional)", ) return parser.parse_args() diff --git a/autobahn/server.py b/autobahn/server.py index c573253..223248d 100644 --- a/autobahn/server.py +++ b/autobahn/server.py @@ -52,9 +52,7 @@ async def handler(request: WebSocketRequest) -> None: def parse_args() -> argparse.Namespace: """Parse command line arguments.""" - parser = argparse.ArgumentParser( - description="Autobahn server for" " trio-websocket" - ) + parser = argparse.ArgumentParser(description="Autobahn server for trio-websocket") parser.add_argument( "-d", "--debug", action="store_true", help="WebSocket URL for server" ) diff --git a/examples/server.py b/examples/server.py index 8e3a9cc..1af40f4 100644 --- a/examples/server.py +++ b/examples/server.py @@ -29,7 +29,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--ssl", action="store_true", help="Use SSL") parser.add_argument( "host", - help="Host interface to bind. If omitted, " "then bind all interfaces.", + help="Host interface to bind. If omitted, then bind all interfaces.", nargs="?", ) parser.add_argument("port", type=int, help="Port to bind.") @@ -45,8 +45,7 @@ async def main(args: argparse.Namespace) -> None: ssl_context.load_cert_chain(here / "fake.server.pem") except FileNotFoundError: logging.error( - 'Did not find file "fake.server.pem". You need to run' - " generate-cert.py" + 'Did not find file "fake.server.pem". You need to run generate-cert.py' ) else: ssl_context = None diff --git a/tests/test_connection.py b/tests/test_connection.py index 73dac90..c018259 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -187,7 +187,9 @@ class MemoryListener(trio.abc.Listener["StapledMemoryStream"]): ] = attr.ib(factory=lambda: trio.open_memory_channel["StapledMemoryStream"](1)) accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None) - async def connect(self) -> trio.StapledStream[ + async def connect( + self, + ) -> trio.StapledStream[ trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream, ]: @@ -196,7 +198,9 @@ async def connect(self) -> trio.StapledStream[ await self.queued_streams[0].send(server) return client - async def accept(self) -> trio.StapledStream[ + async def accept( + self, + ) -> trio.StapledStream[ trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream, ]: @@ -878,7 +882,10 @@ async def test_wrap_client_stream(nursery: trio.Nursery) -> None: await nursery.start(server.run) stream = await listener.connect() conn = await wrap_client_stream( - nursery, stream, HOST, RESOURCE # type: ignore[arg-type] + nursery, + stream, # type: ignore[arg-type] + HOST, + RESOURCE, ) async with conn: assert not conn.closed @@ -1427,7 +1434,10 @@ async def test_remote_close_rude() -> None: async def client() -> None: client_conn = await wrap_client_stream( - nursery, client_stream, HOST, RESOURCE # type: ignore[arg-type] + nursery, + client_stream, # type: ignore[arg-type] + HOST, + RESOURCE, ) assert not client_conn.closed await client_conn.send_message("Hello from client!") From 2f9c10bbb1fb76474d77df30644da76e742e48f2 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 30 Jan 2025 10:50:57 -0600 Subject: [PATCH 37/37] Fix pylint issue --- tests/test_connection.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index c018259..0583523 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -630,8 +630,9 @@ async def handler(request: WebSocketRequest) -> None: assert exc_info.value is user_error e_context = exc_info.value.__context__ assert isinstance( - e_context, BaseExceptionGroup - ) # pylint: disable=possibly-used-before-assignment + e_context, + BaseExceptionGroup, # pylint: disable=possibly-used-before-assignment + ) assert internal_error in e_context.exceptions assert user_error_context in e_context.exceptions