diff --git a/cheroot/makefile.py b/cheroot/makefile.py index f5780a1ede..61faca3ed5 100644 --- a/cheroot/makefile.py +++ b/cheroot/makefile.py @@ -3,6 +3,9 @@ # prefer slower Python-based io module import _pyio as io import socket +import time + +from OpenSSL import SSL # Write only 16K at a time to sockets @@ -31,7 +34,24 @@ def _flush_unlocked(self): # so perhaps we should conditionally wrap this for perf? n = self.raw.write(bytes(self._write_buf)) except io.BlockingIOError as e: + # some data may have been written + # we need to remove that from the buffer before retryings n = e.characters_written + except ( + SSL.WantReadError, + SSL.WantWriteError, + SSL.WantX509LookupError, + ): + # these errors require retries with the same data + # regardless of whether data has already been written + continue + except OSError: + # This catches errors like EBADF (Bad File Descriptor) + # or EPIPE (Broken pipe), which indicate the underlying + # socket is already closed or invalid. + # Since this happens in __del__, we silently stop flushing. + self._write_buf.clear() + return # Exit the function del self._write_buf[:n] @@ -45,9 +65,22 @@ def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): def read(self, *args, **kwargs): """Capture bytes read.""" - val = super().read(*args, **kwargs) - self.bytes_read += len(val) - return val + MAX_ATTEMPTS = 10 + last_error = None + for _ in range(MAX_ATTEMPTS): + try: + val = super().read(*args, **kwargs) + except (SSL.WantReadError, SSL.WantWriteError) as ssl_want_error: + last_error = ssl_want_error + time.sleep(0.1) + else: + self.bytes_read += len(val) + return val + + # If we get here, all attempts failed + raise TimeoutError( + 'Max retries exceeded while waiting for data.', + ) from last_error def has_data(self): """Return true if there is buffered data to read.""" diff --git a/cheroot/test/test_ssl.py b/cheroot/test/test_ssl.py index 8fd597e33a..6efe181390 100644 --- a/cheroot/test/test_ssl.py +++ b/cheroot/test/test_ssl.py @@ -1,5 +1,6 @@ """Tests for TLS support.""" +import errno import functools import http.client import json @@ -17,6 +18,8 @@ import requests import trustme +from cheroot.makefile import BufferedWriter + from .._compat import ( IS_ABOVE_OPENSSL10, IS_ABOVE_OPENSSL31, @@ -625,6 +628,186 @@ def test_ssl_env( # noqa: C901 # FIXME ) +@pytest.fixture +def mock_raw_open_socket(mocker): + """Return a mocked raw socket prepared for writing (closed=False).""" + # This fixture sets the state on the injected object + mock_raw = mocker.Mock(name='mock_raw_socket') + mock_raw.closed = False + return mock_raw + + +@pytest.fixture +def ssl_writer(mock_raw_open_socket): + """Return a BufferedWriter instance with a mocked raw socket.""" + return BufferedWriter(mock_raw_open_socket) + + +def test_want_write_error_retry(ssl_writer, mock_raw_open_socket): + """Test that WantWriteError causes retry with same data.""" + test_data = b'hello world' + + # set up mock socket so that when its write() method is called, + # we get WantWriteError first, then success on the second call + # indicated by returning the number of bytes written + mock_raw_open_socket.write.side_effect = [ + OpenSSL.SSL.WantWriteError(), + len(test_data), + ] + + bytes_written = ssl_writer.write(test_data) + assert bytes_written == len(test_data) + + # Assert against the injected mock object + assert mock_raw_open_socket.write.call_count == 2 + + +def test_want_read_error_retry(ssl_writer, mock_raw_open_socket): + """Test that WantReadError causes retry with same data.""" + test_data = b'test data' + + # set up mock socket so that when its write() method is called, + # we get WantReadError first, then success on the second call + # indicated by returning the number of bytes written + mock_raw_open_socket.write.side_effect = [ + OpenSSL.SSL.WantReadError(), + len(test_data), + ] + + bytes_written = ssl_writer.write(test_data) + assert bytes_written == len(test_data) + + +@pytest.fixture( + params=('builtin', 'pyopenssl'), +) +def adapter_type(request): + """Fixture that yields the name of the SSL adapter.""" + return request.param + + +@pytest.fixture +def create_side_effects_factory(adapter_type): + """ + Fixture that returns a factory function to create the side effect list. + + The factory function returns a list of two items: + 1. An error to be raised on the first call + 2. The length of data written on the second call + + It returns a function that takes one argument, + allowing the data length to be injected from the test function. + """ + if adapter_type == 'pyopenssl': + failure_error = OpenSSL.SSL.WantWriteError() + else: # adapter_type == 'builtin' + failure_error = BlockingIOError( + errno.EWOULDBLOCK, + 'Resource temporarily unavailable', + ) + failure_error.characters_written = 0 + + def generate_side_effects(data_length): # noqa: WPS430 + """Return the list: [failure_error, data_length].""" + return [ + failure_error, + data_length, # This uses the length provided by the test + ] + + # Return the inner function + return generate_side_effects + + +@pytest.fixture +def ssl_writer_integration( + mocker, + mock_raw_open_socket, + adapter_type, + tls_certificate_chain_pem_path, + tls_certificate_private_key_pem_path, +): + """ + Set up mock SSL writer for integration test. + + Mocks the lowest-level write/send method to simulate a + WantWriteError for the PYOPENSSL adapter, and a + BlockingIOError for the BUILTIN adapter. + """ + # Set up SSL adapter + tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) + tls_adapter = tls_adapter_cls( + tls_certificate_chain_pem_path, + tls_certificate_private_key_pem_path, + ) + + if adapter_type == 'pyopenssl': + # --- PYOPENSSL SETUP + # Ensure context is initialized, as required by an OpenSSL Connection + tls_adapter.context = tls_adapter.get_context() + # need to mock a dummy fd on the mocked raw socket + mock_raw_open_socket.fileno.return_value = 1 + + # Create an OpenSSL.SSL.Connection object using the mocked raw socket + ssl_conn = OpenSSL.SSL.Connection( + tls_adapter.context, + mock_raw_open_socket, + ) + ssl_conn.set_connect_state() + ssl_conn.closed = False + + # we need to mock a write method on the mocked raw socket + raw_io_object = ssl_conn + raw_io_object.write = mocker.Mock(name='ssl_conn_write_mock') + else: + # adapter_type == 'builtin' + # --- BUILTIN ADAPTER SETUP (Requires different mocking) --- + # we need to mock the adapter's own write and writable methods + raw_io_object = tls_adapter + raw_io_object.writable = mocker.Mock(return_value=True) + raw_io_object.write = mocker.Mock( + name='builtin_adapter_write', + ) + raw_io_object.closed = False + + # return mock assertion target + return raw_io_object + + +def test_want_write_error_integration( + ssl_writer_integration, + create_side_effects_factory, +): + """Integration test for SSL writer handling of WantWriteError. + + This test gets called twice, once for each adapter type. + The fixture ssl_writer_integration sets up the mock write method. + The fixture create_side_effects_factory creates the side effect list + with the data length injected from this test function. + """ + raw_io_object = ssl_writer_integration + test_data = b'integration test data' + successful_write_length = len(test_data) + + # Call side effects factory function to create + # a two step list for the mock write method. + # First call raises error, second call returns length. + # We have to inject the length because the factory + # is created in a fixture that doesn't know the test data. + side_effects = create_side_effects_factory(successful_write_length) + raw_io_object.write.side_effect = side_effects + + writer = BufferedWriter(raw_io_object) + + # write data and then flush + # with the way the mock_write is set up this should fail once, + # and then succeed on the retry. + bytes_written = writer.write(test_data) + writer.flush() + + assert bytes_written == successful_write_length + assert raw_io_object.write.call_count == 2 + + @pytest.mark.parametrize( 'ip_addr', ( diff --git a/docs/changelog-fragments.d/764.bugfix.rst b/docs/changelog-fragments.d/764.bugfix.rst new file mode 100644 index 0000000000..053da889ea --- /dev/null +++ b/docs/changelog-fragments.d/764.bugfix.rst @@ -0,0 +1,7 @@ +Added handling for WantWriteError and WantReadError in BufferedWriter +and StreamReader to enable retries. This addresses long standing issues +discussed in #245. The reliability of the fix relies on using pyOpenSSL +v25.2.0 or greater, as earlier versions have known bugs that affect +the retry logic. + +-- by :user:`julianz-` diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index f9020b9ce6..35e085a903 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -43,6 +43,7 @@ positionally pre preconfigure py +pyOpenSSL pytest pythonic readonly