@@ -308,7 +308,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
308
308
return server
309
309
310
310
311
- def loopback (server_factory = None , client_factory = None ):
311
+ def loopback (server_factory = None , client_factory = None , blocking = True ):
312
312
"""
313
313
Create a connected socket pair and force two connected SSL sockets
314
314
to talk to each other via memory BIOs.
@@ -324,8 +324,8 @@ def loopback(server_factory=None, client_factory=None):
324
324
325
325
handshake (client , server )
326
326
327
- server .setblocking (True )
328
- client .setblocking (True )
327
+ server .setblocking (blocking )
328
+ client .setblocking (blocking )
329
329
return server , client
330
330
331
331
@@ -3131,11 +3131,131 @@ def test_memoryview_really_doesnt_overfill(self):
3131
3131
self ._doesnt_overfill_test (_make_memoryview )
3132
3132
3133
3133
3134
+ @pytest .fixture
3135
+ def nonblocking_tls_connections_pair ():
3136
+ """Return a non-blocking TLS loopback connections pair."""
3137
+ return loopback (blocking = False )
3138
+
3139
+
3140
+ @pytest .fixture
3141
+ def nonblocking_tls_server_connection (nonblocking_tls_connections_pair ):
3142
+ """Return a non-blocking TLS server socket connected to loopback."""
3143
+ return nonblocking_tls_connections_pair [0 ]
3144
+
3145
+
3146
+ @pytest .fixture
3147
+ def nonblocking_tls_client_connection (nonblocking_tls_connections_pair ):
3148
+ """Return a non-blocking TLS client socket connected to loopback."""
3149
+ return nonblocking_tls_connections_pair [1 ]
3150
+
3151
+
3134
3152
class TestConnectionSendall (object ):
3135
3153
"""
3136
3154
Tests for `Connection.sendall`.
3137
3155
"""
3138
3156
3157
+ def test_want_write (
3158
+ self ,
3159
+ monkeypatch ,
3160
+ nonblocking_tls_server_connection ,
3161
+ nonblocking_tls_client_connection ,
3162
+ ):
3163
+ msg = b"x"
3164
+ garbage_size = 1024 * 1024 * 64
3165
+ large_payload = b"p" * garbage_size * 2
3166
+ payload_size = len (large_payload )
3167
+
3168
+ sent_garbage_size = 0
3169
+ try :
3170
+ sent_garbage_size += nonblocking_tls_client_connection .send (
3171
+ msg * garbage_size ,
3172
+ )
3173
+ except WantWriteError :
3174
+ pass
3175
+ for i in range (garbage_size ):
3176
+ try :
3177
+ sent_garbage_size += nonblocking_tls_client_connection .send (
3178
+ msg ,
3179
+ )
3180
+ except WantWriteError :
3181
+ break
3182
+ else :
3183
+ pytest .fail (
3184
+ "Failed to fill socket buffer, cannot test "
3185
+ "'want write' in `sendall()`"
3186
+ )
3187
+ garbage_payload = sent_garbage_size * msg
3188
+
3189
+
3190
+ def consume_garbage (conn ):
3191
+ assert patched_ssl_write .want_write_counter >= 1
3192
+ assert not consume_garbage .garbage_consumed
3193
+
3194
+ while len (consume_garbage .consumed ) < sent_garbage_size :
3195
+ try :
3196
+ consume_garbage .consumed += conn .recv (
3197
+ sent_garbage_size - len (consume_garbage .consumed ),
3198
+ )
3199
+ except WantReadError :
3200
+ pass
3201
+
3202
+ assert consume_garbage .consumed == garbage_payload
3203
+
3204
+ consume_garbage .garbage_consumed = True
3205
+
3206
+ consume_garbage .garbage_consumed = False
3207
+ consume_garbage .consumed = b""
3208
+
3209
+ def consume_payload (conn ):
3210
+ try :
3211
+ consume_payload .consumed += conn .recv (payload_size )
3212
+ except WantReadError :
3213
+ pass
3214
+ consume_payload .consumed = b""
3215
+
3216
+ original_ssl_write = _lib .SSL_write
3217
+ def patched_ssl_write (ctx , data , size ):
3218
+ write_result = original_ssl_write (ctx , data , size )
3219
+ try :
3220
+ nonblocking_tls_client_connection ._raise_ssl_error (
3221
+ ctx , write_result ,
3222
+ )
3223
+ except WantWriteError :
3224
+ patched_ssl_write .want_write_counter += 1
3225
+ consume_data_on_server = (
3226
+ consume_payload if consume_garbage .garbage_consumed
3227
+ else consume_garbage
3228
+ )
3229
+
3230
+ consume_data_on_server (nonblocking_tls_server_connection )
3231
+ # NOTE: We don't re-raise it as the calling code will do
3232
+ # NOTE: the same after the call.
3233
+ return write_result
3234
+
3235
+ patched_ssl_write .want_write_counter = 0
3236
+
3237
+ # NOTE: Make the client think it needs a handshake so that it'll
3238
+ # NOTE: attempt to `do_handshake()` on the next `SSL_write()`
3239
+ # NOTE: that originates from `sendall()`:
3240
+ nonblocking_tls_client_connection .set_connect_state ()
3241
+ try :
3242
+ nonblocking_tls_client_connection .do_handshake ()
3243
+ except WantWriteError :
3244
+ assert True # Sanity check
3245
+ except :
3246
+ assert False # This should never happen (see the note above)
3247
+
3248
+ with monkeypatch .context () as mp_ctx :
3249
+ mp_ctx .setattr (_lib , "SSL_write" , patched_ssl_write )
3250
+ nonblocking_tls_client_connection .sendall (large_payload )
3251
+
3252
+ assert consume_garbage .garbage_consumed
3253
+
3254
+ # NOTE: Read the leftover data from the very last `SSL_write()`
3255
+ consume_payload (nonblocking_tls_server_connection )
3256
+
3257
+ assert consume_payload .consumed == large_payload
3258
+
3139
3259
def test_wrong_args (self ):
3140
3260
"""
3141
3261
When called with arguments other than a string argument for its first
0 commit comments