Skip to content

Commit 9cb77f9

Browse files
committed
Add a test for WANT_READ during sendall()
1 parent 124a013 commit 9cb77f9

File tree

1 file changed

+123
-3
lines changed

1 file changed

+123
-3
lines changed

tests/test_ssl.py

+123-3
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
308308
return server
309309

310310

311-
def loopback(server_factory=None, client_factory=None):
311+
def loopback(server_factory=None, client_factory=None, blocking=True):
312312
"""
313313
Create a connected socket pair and force two connected SSL sockets
314314
to talk to each other via memory BIOs.
@@ -324,8 +324,8 @@ def loopback(server_factory=None, client_factory=None):
324324

325325
handshake(client, server)
326326

327-
server.setblocking(True)
328-
client.setblocking(True)
327+
server.setblocking(blocking)
328+
client.setblocking(blocking)
329329
return server, client
330330

331331

@@ -3131,11 +3131,131 @@ def test_memoryview_really_doesnt_overfill(self):
31313131
self._doesnt_overfill_test(_make_memoryview)
31323132

31333133

3134+
@pytest.fixture
3135+
def nonblocking_tls_connections_pair():
3136+
"""Return a non-blocking TLS loopback connections pair."""
3137+
return loopback(blocking=False)
3138+
3139+
3140+
@pytest.fixture
3141+
def nonblocking_tls_server_connection(nonblocking_tls_connections_pair):
3142+
"""Return a non-blocking TLS server socket connected to loopback."""
3143+
return nonblocking_tls_connections_pair[0]
3144+
3145+
3146+
@pytest.fixture
3147+
def nonblocking_tls_client_connection(nonblocking_tls_connections_pair):
3148+
"""Return a non-blocking TLS client socket connected to loopback."""
3149+
return nonblocking_tls_connections_pair[1]
3150+
3151+
31343152
class TestConnectionSendall(object):
31353153
"""
31363154
Tests for `Connection.sendall`.
31373155
"""
31383156

3157+
def test_want_write(
3158+
self,
3159+
monkeypatch,
3160+
nonblocking_tls_server_connection,
3161+
nonblocking_tls_client_connection,
3162+
):
3163+
msg = b"x"
3164+
garbage_size = 1024 * 1024 * 64
3165+
large_payload = b"p" * garbage_size * 2
3166+
payload_size = len(large_payload)
3167+
3168+
sent_garbage_size = 0
3169+
try:
3170+
sent_garbage_size += nonblocking_tls_client_connection.send(
3171+
msg * garbage_size,
3172+
)
3173+
except WantWriteError:
3174+
pass
3175+
for i in range(garbage_size):
3176+
try:
3177+
sent_garbage_size += nonblocking_tls_client_connection.send(
3178+
msg,
3179+
)
3180+
except WantWriteError:
3181+
break
3182+
else:
3183+
pytest.fail(
3184+
"Failed to fill socket buffer, cannot test "
3185+
"'want write' in `sendall()`"
3186+
)
3187+
garbage_payload = sent_garbage_size * msg
3188+
3189+
3190+
def consume_garbage(conn):
3191+
assert patched_ssl_write.want_write_counter >= 1
3192+
assert not consume_garbage.garbage_consumed
3193+
3194+
while len(consume_garbage.consumed) < sent_garbage_size:
3195+
try:
3196+
consume_garbage.consumed += conn.recv(
3197+
sent_garbage_size - len(consume_garbage.consumed),
3198+
)
3199+
except WantReadError:
3200+
pass
3201+
3202+
assert consume_garbage.consumed == garbage_payload
3203+
3204+
consume_garbage.garbage_consumed = True
3205+
3206+
consume_garbage.garbage_consumed = False
3207+
consume_garbage.consumed = b""
3208+
3209+
def consume_payload(conn):
3210+
try:
3211+
consume_payload.consumed += conn.recv(payload_size)
3212+
except WantReadError:
3213+
pass
3214+
consume_payload.consumed = b""
3215+
3216+
original_ssl_write = _lib.SSL_write
3217+
def patched_ssl_write(ctx, data, size):
3218+
write_result = original_ssl_write(ctx, data, size)
3219+
try:
3220+
nonblocking_tls_client_connection._raise_ssl_error(
3221+
ctx, write_result,
3222+
)
3223+
except WantWriteError:
3224+
patched_ssl_write.want_write_counter += 1
3225+
consume_data_on_server = (
3226+
consume_payload if consume_garbage.garbage_consumed
3227+
else consume_garbage
3228+
)
3229+
3230+
consume_data_on_server(nonblocking_tls_server_connection)
3231+
# NOTE: We don't re-raise it as the calling code will do
3232+
# NOTE: the same after the call.
3233+
return write_result
3234+
3235+
patched_ssl_write.want_write_counter = 0
3236+
3237+
# NOTE: Make the client think it needs a handshake so that it'll
3238+
# NOTE: attempt to `do_handshake()` on the next `SSL_write()`
3239+
# NOTE: that originates from `sendall()`:
3240+
nonblocking_tls_client_connection.set_connect_state()
3241+
try:
3242+
nonblocking_tls_client_connection.do_handshake()
3243+
except WantWriteError:
3244+
assert True # Sanity check
3245+
except:
3246+
assert False # This should never happen (see the note above)
3247+
3248+
with monkeypatch.context() as mp_ctx:
3249+
mp_ctx.setattr(_lib, "SSL_write", patched_ssl_write)
3250+
nonblocking_tls_client_connection.sendall(large_payload)
3251+
3252+
assert consume_garbage.garbage_consumed
3253+
3254+
# NOTE: Read the leftover data from the very last `SSL_write()`
3255+
consume_payload(nonblocking_tls_server_connection)
3256+
3257+
assert consume_payload.consumed == large_payload
3258+
31393259
def test_wrong_args(self):
31403260
"""
31413261
When called with arguments other than a string argument for its first

0 commit comments

Comments
 (0)