@@ -24,10 +24,13 @@ def _or_inf(value: typing.Optional[float]) -> float:
24
24
25
25
26
26
class Stream (BaseStream ):
27
- def __init__ (self , stream : trio .abc .Stream , timeout : TimeoutConfig ) -> None :
27
+ def __init__ (
28
+ self ,
29
+ stream : typing .Union [trio .SocketStream , trio .SSLStream ],
30
+ timeout : TimeoutConfig ,
31
+ ) -> None :
28
32
self .stream = stream
29
33
self .timeout = timeout
30
- self .is_eof = False
31
34
self .write_buffer = b""
32
35
self .write_lock = trio .Lock ()
33
36
@@ -54,18 +57,18 @@ async def read(
54
57
read_timeout = _or_inf (timeout .read_timeout if should_raise else 0.01 )
55
58
56
59
with trio .move_on_after (read_timeout ):
57
- data = await self .stream .receive_some (max_bytes = n )
58
- # b"" is the expected EOF message for Trio.
59
- # The other case is an edge case that occurs with uvicorn+httptools.
60
- if data == b"" or data .endswith (b"0\r \n \r \n " ):
61
- self .is_eof = True
62
- return data
60
+ return await self .stream .receive_some (max_bytes = n )
63
61
64
62
if should_raise :
65
63
raise ReadTimeout () from None
66
64
67
65
def is_connection_dropped (self ) -> bool :
68
- return self .is_eof
66
+ stream = self .stream
67
+ # Peek through any SSLStream wrappers to get the underlying SocketStream.
68
+ while hasattr (stream , "transport_stream" ):
69
+ stream = stream .transport_stream
70
+ assert isinstance (stream , trio .SocketStream )
71
+ return not stream .socket .is_readable ()
69
72
70
73
def write_no_block (self , data : bytes ) -> None :
71
74
self .write_buffer += data
0 commit comments