@@ -321,7 +321,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
321
321
return server
322
322
323
323
324
- def loopback (server_factory = None , client_factory = None ):
324
+ def loopback (server_factory = None , client_factory = None , blocking = True ):
325
325
"""
326
326
Create a connected socket pair and force two connected SSL sockets
327
327
to talk to each other via memory BIOs.
@@ -337,8 +337,8 @@ def loopback(server_factory=None, client_factory=None):
337
337
338
338
handshake (client , server )
339
339
340
- server .setblocking (True )
341
- client .setblocking (True )
340
+ server .setblocking (blocking )
341
+ client .setblocking (blocking )
342
342
return server , client
343
343
344
344
@@ -3292,11 +3292,134 @@ def test_memoryview_really_doesnt_overfill(self):
3292
3292
self ._doesnt_overfill_test (_make_memoryview )
3293
3293
3294
3294
3295
+ @pytest .fixture
3296
+ def nonblocking_tls_connections_pair ():
3297
+ """Return a non-blocking TLS loopback connections pair."""
3298
+ return loopback (blocking = False )
3299
+
3300
+
3301
+ @pytest .fixture
3302
+ def nonblocking_tls_server_connection (nonblocking_tls_connections_pair ):
3303
+ """Return a non-blocking TLS server socket connected to loopback."""
3304
+ return nonblocking_tls_connections_pair [0 ]
3305
+
3306
+
3307
+ @pytest .fixture
3308
+ def nonblocking_tls_client_connection (nonblocking_tls_connections_pair ):
3309
+ """Return a non-blocking TLS client socket connected to loopback."""
3310
+ return nonblocking_tls_connections_pair [1 ]
3311
+
3312
+
3295
3313
class TestConnectionSendall :
3296
3314
"""
3297
3315
Tests for `Connection.sendall`.
3298
3316
"""
3299
3317
3318
+ def test_want_write (
3319
+ self ,
3320
+ monkeypatch ,
3321
+ nonblocking_tls_server_connection ,
3322
+ nonblocking_tls_client_connection ,
3323
+ ):
3324
+ msg = b"x"
3325
+ garbage_size = 1024 * 1024 * 64
3326
+ large_payload = b"p" * garbage_size * 2
3327
+ payload_size = len (large_payload )
3328
+
3329
+ sent_garbage_size = 0
3330
+ try :
3331
+ sent_garbage_size += nonblocking_tls_client_connection .send (
3332
+ msg * garbage_size ,
3333
+ )
3334
+ except WantWriteError :
3335
+ pass
3336
+ for i in range (garbage_size ):
3337
+ try :
3338
+ sent_garbage_size += nonblocking_tls_client_connection .send (
3339
+ msg ,
3340
+ )
3341
+ except WantWriteError :
3342
+ break
3343
+ else :
3344
+ pytest .fail (
3345
+ "Failed to fill socket buffer, cannot test "
3346
+ "'want write' in `sendall()`"
3347
+ )
3348
+ garbage_payload = sent_garbage_size * msg
3349
+
3350
+ def consume_garbage (conn ):
3351
+ assert patched_ssl_write .want_write_counter >= 1
3352
+ assert not consume_garbage .garbage_consumed
3353
+
3354
+ while len (consume_garbage .consumed ) < sent_garbage_size :
3355
+ try :
3356
+ consume_garbage .consumed += conn .recv (
3357
+ sent_garbage_size - len (consume_garbage .consumed ),
3358
+ )
3359
+ except WantReadError :
3360
+ pass
3361
+
3362
+ assert consume_garbage .consumed == garbage_payload
3363
+
3364
+ consume_garbage .garbage_consumed = True
3365
+
3366
+ consume_garbage .garbage_consumed = False
3367
+ consume_garbage .consumed = b""
3368
+
3369
+ def consume_payload (conn ):
3370
+ try :
3371
+ consume_payload .consumed += conn .recv (payload_size )
3372
+ except WantReadError :
3373
+ pass
3374
+
3375
+ consume_payload .consumed = b""
3376
+
3377
+ original_ssl_write = _lib .SSL_write
3378
+
3379
+ def patched_ssl_write (ctx , data , size ):
3380
+ write_result = original_ssl_write (ctx , data , size )
3381
+ try :
3382
+ nonblocking_tls_client_connection ._raise_ssl_error (
3383
+ ctx ,
3384
+ write_result ,
3385
+ )
3386
+ except WantWriteError :
3387
+ patched_ssl_write .want_write_counter += 1
3388
+ consume_data_on_server = (
3389
+ consume_payload
3390
+ if consume_garbage .garbage_consumed
3391
+ else consume_garbage
3392
+ )
3393
+
3394
+ consume_data_on_server (nonblocking_tls_server_connection )
3395
+ # NOTE: We don't re-raise it as the calling code will do
3396
+ # NOTE: the same after the call.
3397
+ return write_result
3398
+
3399
+ patched_ssl_write .want_write_counter = 0
3400
+
3401
+ # NOTE: Make the client think it needs a handshake so that it'll
3402
+ # NOTE: attempt to `do_handshake()` on the next `SSL_write()`
3403
+ # NOTE: that originates from `sendall()`:
3404
+ nonblocking_tls_client_connection .set_connect_state ()
3405
+ try :
3406
+ nonblocking_tls_client_connection .do_handshake ()
3407
+ except WantWriteError :
3408
+ assert True # Sanity check
3409
+ except :
3410
+ assert False # This should never happen (see the note above)
3411
+
3412
+ with monkeypatch .context () as mp_ctx :
3413
+ mp_ctx .setattr (_lib , "SSL_write" , patched_ssl_write )
3414
+ nonblocking_tls_client_connection .sendall (large_payload )
3415
+
3416
+ assert consume_garbage .garbage_consumed
3417
+
3418
+ # NOTE: Read the leftover data from the very last `SSL_write()`
3419
+ consume_payload (nonblocking_tls_server_connection )
3420
+
3421
+ assert consume_payload .consumed == large_payload
3422
+
3300
3423
def test_wrong_args (self ):
3301
3424
"""
3302
3425
When called with arguments other than a string argument for its first
0 commit comments