Skip to content

Commit 7df79f1

Browse files
committed
keep-alive: always read body to end
1 parent 69bed38 commit 7df79f1

12 files changed

+629
-107
lines changed

Cargo.toml

+1-5
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,8 @@ futures-core = "0.3.8"
2727
log = "0.4.11"
2828
pin-project = "1.0.2"
2929
async-channel = "1.5.1"
30+
async-dup = "1.2.2"
3031

3132
[dev-dependencies]
3233
pretty_assertions = "0.6.1"
3334
async-std = { version = "1.7.0", features = ["attributes"] }
34-
tempfile = "3.1.0"
35-
async-test = "1.0.0"
36-
duplexify = "1.2.2"
37-
async-dup = "1.2.2"
38-
async-channel = "1.5.1"

src/chunked/decoder.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ lazy_static::lazy_static! {
1919

2020
/// Decodes a chunked body according to
2121
/// https://tools.ietf.org/html/rfc7230#section-4.1
22-
pub(crate) struct ChunkedDecoder<R: Read> {
22+
#[derive(Debug)]
23+
pub struct ChunkedDecoder<R: Read> {
2324
/// The underlying stream
2425
inner: R,
2526
/// Buffer for the already read, but not yet parsed data.

src/read_notifier.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ impl<B> fmt::Debug for ReadNotifier<B> {
2626
}
2727
}
2828

29-
impl<B: BufRead> ReadNotifier<B> {
29+
impl<B: Read> ReadNotifier<B> {
3030
pub(crate) fn new(reader: B, sender: Sender<()>) -> Self {
3131
Self {
3232
reader,

src/server/body_reader.rs

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
use crate::chunked::ChunkedDecoder;
2+
use async_dup::{Arc, Mutex};
3+
use async_std::io::{BufReader, Read, Take};
4+
use async_std::task::{Context, Poll};
5+
use std::{fmt::Debug, io, pin::Pin};
6+
7+
pub enum BodyReader<IO: Read + Unpin> {
8+
Chunked(Arc<Mutex<ChunkedDecoder<BufReader<IO>>>>),
9+
Fixed(Arc<Mutex<Take<BufReader<IO>>>>),
10+
None,
11+
}
12+
13+
impl<IO: Read + Unpin> Debug for BodyReader<IO> {
14+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15+
match self {
16+
BodyReader::Chunked(_) => f.write_str("BodyReader::Chunked"),
17+
BodyReader::Fixed(_) => f.write_str("BodyReader::Fixed"),
18+
BodyReader::None => f.write_str("BodyReader::None"),
19+
}
20+
}
21+
}
22+
23+
impl<IO: Read + Unpin> Read for BodyReader<IO> {
24+
fn poll_read(
25+
self: Pin<&mut Self>,
26+
cx: &mut Context<'_>,
27+
buf: &mut [u8],
28+
) -> Poll<io::Result<usize>> {
29+
match &*self {
30+
BodyReader::Chunked(r) => Pin::new(&mut *r.lock()).poll_read(cx, buf),
31+
BodyReader::Fixed(r) => Pin::new(&mut *r.lock()).poll_read(cx, buf),
32+
BodyReader::None => Poll::Ready(Ok(0)),
33+
}
34+
}
35+
}

src/server/decode.rs

+24-19
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
33
use std::str::FromStr;
44

5+
use async_dup::{Arc, Mutex};
56
use async_std::io::{BufReader, Read, Write};
67
use async_std::{prelude::*, task};
78
use http_types::headers::{CONTENT_LENGTH, EXPECT, TRANSFER_ENCODING};
89
use http_types::{ensure, ensure_eq, format_err};
910
use http_types::{Body, Method, Request, Url};
1011

12+
use super::body_reader::BodyReader;
1113
use crate::chunked::ChunkedDecoder;
1214
use crate::read_notifier::ReadNotifier;
1315
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
@@ -21,7 +23,7 @@ const CONTINUE_HEADER_VALUE: &str = "100-continue";
2123
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
2224

2325
/// Decode an HTTP request on the server.
24-
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<Request>>
26+
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<(Request, BodyReader<IO>)>>
2527
where
2628
IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
2729
{
@@ -108,26 +110,29 @@ where
108110
}
109111

110112
// Check for Transfer-Encoding
111-
if let Some(encoding) = transfer_encoding {
112-
if encoding.last().as_str() == "chunked" {
113-
let trailer_sender = req.send_trailers();
114-
let reader = ChunkedDecoder::new(reader, trailer_sender);
115-
let reader = BufReader::new(reader);
116-
let reader = ReadNotifier::new(reader, body_read_sender);
117-
req.set_body(Body::from_reader(reader, None));
118-
return Ok(Some(req));
119-
}
120-
// Fall through to Content-Length
121-
}
122-
123-
// Check for Content-Length.
124-
if let Some(len) = content_length {
113+
if transfer_encoding
114+
.map(|te| te.as_str().eq_ignore_ascii_case("chunked"))
115+
.unwrap_or(false)
116+
{
117+
let trailer_sender = req.send_trailers();
118+
let reader = ChunkedDecoder::new(reader, trailer_sender);
119+
let reader = Arc::new(Mutex::new(reader));
120+
let reader_clone = reader.clone();
121+
let reader = ReadNotifier::new(reader, body_read_sender);
122+
let reader = BufReader::new(reader);
123+
req.set_body(Body::from_reader(reader, None));
124+
return Ok(Some((req, BodyReader::Chunked(reader_clone))));
125+
} else if let Some(len) = content_length {
125126
let len = len.last().as_str().parse::<usize>()?;
126-
let reader = ReadNotifier::new(reader.take(len as u64), body_read_sender);
127-
req.set_body(Body::from_reader(reader, Some(len)));
127+
let reader = Arc::new(Mutex::new(reader.take(len as u64)));
128+
req.set_body(Body::from_reader(
129+
BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)),
130+
Some(len),
131+
));
132+
Ok(Some((req, BodyReader::Fixed(reader))))
133+
} else {
134+
Ok(Some((req, BodyReader::None)))
128135
}
129-
130-
Ok(Some(req))
131136
}
132137

133138
fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<Url> {

src/server/mod.rs

+91-20
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
//! Process HTTP connections on the server.
22
3-
use std::time::Duration;
4-
53
use async_std::future::{timeout, Future, TimeoutError};
64
use async_std::io::{self, Read, Write};
75
use http_types::headers::{CONNECTION, UPGRADE};
86
use http_types::upgrade::Connection;
97
use http_types::{Request, Response, StatusCode};
10-
8+
use std::{marker::PhantomData, time::Duration};
9+
mod body_reader;
1110
mod decode;
1211
mod encode;
1312

@@ -38,14 +37,14 @@ where
3837
F: Fn(Request) -> Fut,
3938
Fut: Future<Output = http_types::Result<Response>>,
4039
{
41-
accept_with_opts(io, endpoint, Default::default()).await
40+
Server::new(io, endpoint).accept().await
4241
}
4342

4443
/// Accept a new incoming HTTP/1.1 connection.
4544
///
4645
/// Supports `KeepAlive` requests by default.
4746
pub async fn accept_with_opts<RW, F, Fut>(
48-
mut io: RW,
47+
io: RW,
4948
endpoint: F,
5049
opts: ServerOptions,
5150
) -> http_types::Result<()>
@@ -54,35 +53,99 @@ where
5453
F: Fn(Request) -> Fut,
5554
Fut: Future<Output = http_types::Result<Response>>,
5655
{
57-
loop {
56+
Server::new(io, endpoint).with_opts(opts).accept().await
57+
}
58+
59+
/// struct for server
60+
#[derive(Debug)]
61+
pub struct Server<RW, F, Fut> {
62+
io: RW,
63+
endpoint: F,
64+
opts: ServerOptions,
65+
_phantom: PhantomData<Fut>,
66+
}
67+
68+
/// An enum that represents whether the server should accept a subsequent request
69+
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
70+
pub enum ConnectionStatus {
71+
/// The server should not accept another request
72+
Close,
73+
74+
/// The server may accept another request
75+
KeepAlive,
76+
}
77+
78+
impl<RW, F, Fut> Server<RW, F, Fut>
79+
where
80+
RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
81+
F: Fn(Request) -> Fut,
82+
Fut: Future<Output = http_types::Result<Response>>,
83+
{
84+
/// builds a new server
85+
pub fn new(io: RW, endpoint: F) -> Self {
86+
Self {
87+
io,
88+
endpoint,
89+
opts: Default::default(),
90+
_phantom: PhantomData,
91+
}
92+
}
93+
94+
/// with opts
95+
pub fn with_opts(mut self, opts: ServerOptions) -> Self {
96+
self.opts = opts;
97+
self
98+
}
99+
100+
/// accept in a loop
101+
pub async fn accept(&mut self) -> http_types::Result<()> {
102+
while ConnectionStatus::KeepAlive == self.accept_one().await? {}
103+
Ok(())
104+
}
105+
106+
/// accept one request
107+
pub async fn accept_one(&mut self) -> http_types::Result<ConnectionStatus>
108+
where
109+
RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
110+
F: Fn(Request) -> Fut,
111+
Fut: Future<Output = http_types::Result<Response>>,
112+
{
58113
// Decode a new request, timing out if this takes longer than the timeout duration.
59-
let fut = decode(io.clone());
114+
let fut = decode(self.io.clone());
60115

61-
let req = if let Some(timeout_duration) = opts.headers_timeout {
116+
let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout {
62117
match timeout(timeout_duration, fut).await {
63118
Ok(Ok(Some(r))) => r,
64-
Ok(Ok(None)) | Err(TimeoutError { .. }) => break, /* EOF or timeout */
119+
Ok(Ok(None)) | Err(TimeoutError { .. }) => return Ok(ConnectionStatus::Close), /* EOF or timeout */
65120
Ok(Err(e)) => return Err(e),
66121
}
67122
} else {
68123
match fut.await? {
69124
Some(r) => r,
70-
None => break, /* EOF */
125+
None => return Ok(ConnectionStatus::Close), /* EOF */
71126
}
72127
};
73128

74129
let has_upgrade_header = req.header(UPGRADE).is_some();
75-
let connection_header_is_upgrade = req
130+
let connection_header_as_str = req
76131
.header(CONNECTION)
77-
.map(|connection| connection.as_str().eq_ignore_ascii_case("upgrade"))
78-
.unwrap_or(false);
132+
.map(|connection| connection.as_str())
133+
.unwrap_or("");
134+
135+
let connection_header_is_upgrade = connection_header_as_str.eq_ignore_ascii_case("upgrade");
136+
let mut close_connection = connection_header_as_str.eq_ignore_ascii_case("close");
79137

80138
let upgrade_requested = has_upgrade_header && connection_header_is_upgrade;
81139

82140
let method = req.method();
83141

84142
// Pass the request to the endpoint and encode the response.
85-
let mut res = endpoint(req).await?;
143+
let mut res = (self.endpoint)(req).await?;
144+
145+
close_connection |= res
146+
.header(CONNECTION)
147+
.map(|c| c.as_str().eq_ignore_ascii_case("close"))
148+
.unwrap_or(false);
86149

87150
let upgrade_provided = res.status() == StatusCode::SwitchingProtocols && res.has_upgrade();
88151

@@ -94,14 +157,22 @@ where
94157

95158
let mut encoder = Encoder::new(res, method);
96159

97-
// Stream the response to the writer.
98-
io::copy(&mut encoder, &mut io).await?;
160+
let bytes_written = io::copy(&mut encoder, &mut self.io).await?;
161+
log::trace!("wrote {} response bytes", bytes_written);
162+
163+
let body_bytes_discarded = io::copy(&mut body, &mut io::sink()).await?;
164+
log::trace!(
165+
"discarded {} unread request body bytes",
166+
body_bytes_discarded
167+
);
99168

100169
if let Some(upgrade_sender) = upgrade_sender {
101-
upgrade_sender.send(Connection::new(io.clone())).await;
102-
return Ok(());
170+
upgrade_sender.send(Connection::new(self.io.clone())).await;
171+
return Ok(ConnectionStatus::Close);
172+
} else if close_connection {
173+
Ok(ConnectionStatus::Close)
174+
} else {
175+
Ok(ConnectionStatus::KeepAlive)
103176
}
104177
}
105-
106-
Ok(())
107178
}

0 commit comments

Comments
 (0)