diff --git a/Cargo.toml b/Cargo.toml index 1d2a9b1..83d32f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ keywords = ["embedded", "async", "http", "no_std"] exclude = [".github"] [dependencies] -buffered-io = { version = "0.5.1" } +buffered-io = { version = "0.5.2" } embedded-io = { version = "0.6" } embedded-io-async = { version = "0.6" } embedded-nal-async = "0.7.0" @@ -27,7 +27,9 @@ defmt = { version = "0.3", optional = true } embedded-tls = { version = "0.17", default-features = false, optional = true } rand_chacha = { version = "0.3", default-features = false } nourl = "0.1.1" -esp-mbedtls = { git = "https://github.com/esp-rs/esp-mbedtls.git", features = ["async"], optional = true } +esp-mbedtls = { git = "https://github.com/esp-rs/esp-mbedtls.git", features = [ + "async", +], optional = true } [dev-dependencies] hyper = { version = "0.14.23", features = ["full"] } diff --git a/src/client.rs b/src/client.rs index 91f717c..2b007c4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -311,31 +311,19 @@ where HttpConnection::Plain(c) => { let mut writer = ChunkedBodyWriter::new(c); body.write(&mut writer).await?; - writer.write_empty_chunk().await.map_err(|e| e.kind())?; + writer.terminate().await.map_err(|e| e.kind())?; } - HttpConnection::PlainBuffered(buffered_conn) => { - // Flush the buffered connection so that we can bypass it and rent its buffer - buffered_conn.flush().await.map_err(|e| e.kind())?; - let (conn, buf) = buffered_conn.bypass_with_buf().unwrap(); - - // Construct a new buffered writer that buffers _before_ the chunked body writer - let mut writer = BufferedWrite::new(ChunkedBodyWriter::new(conn), buf); + HttpConnection::PlainBuffered(buffered) => { + let (conn, buf, unwritten) = buffered.split(); + let mut writer = BufferedChunkedBodyWriter::new_with_data(conn, buf, unwritten); body.write(&mut writer).await?; - - // Flush the buffered writer and write the empty chunk to the chunked body writer - writer.flush().await.map_err(|e| e.kind())?; - writer - .bypass() - .unwrap() - .write_empty_chunk() - .await - .map_err(|e| e.kind())?; + writer.terminate().await.map_err(|e| e.kind())?; } #[cfg(any(feature = "embedded-tls", feature = "esp-mbedtls"))] HttpConnection::Tls(c) => { let mut writer = ChunkedBodyWriter::new(c); body.write(&mut writer).await?; - writer.write_empty_chunk().await.map_err(|e| e.kind())?; + writer.terminate().await.map_err(|e| e.kind())?; } #[cfg(all(not(feature = "embedded-tls"), not(feature = "esp-mbedtls")))] HttpConnection::Tls(_) => unreachable!(), diff --git a/src/request.rs b/src/request.rs index ae1b41f..68e4c8f 100644 --- a/src/request.rs +++ b/src/request.rs @@ -248,7 +248,7 @@ impl Method { } async fn write_str(c: &mut C, data: &str) -> Result<(), Error> { - c.write_all(data.as_bytes()).await.map_err(to_errorkind)?; + c.write_all(data.as_bytes()).await.map_err(|e| e.kind())?; Ok(()) } @@ -355,17 +355,35 @@ where } } -pub struct ChunkedBodyWriter(C, usize); +const fn hex_chars(number: usize) -> u32 { + if number == 0 { + 1 + } else { + (usize::BITS - number.leading_zeros()).div_ceil(4) + } +} + +fn write_chunked_header(buf: &mut [u8], chunk_len: usize) -> usize { + let mut hex = [0; 2 * size_of::()]; + hex::encode_to_slice(chunk_len.to_be_bytes(), &mut hex).unwrap(); + let leading_zeros = hex.iter().position(|x| *x != b'0').unwrap_or_default(); + let hex_chars = hex.len() - leading_zeros; + buf[..hex_chars].copy_from_slice(&hex[leading_zeros..]); + buf[hex_chars..hex_chars + 2].copy_from_slice(b"\r\n"); + hex_chars + 2 +} + +pub struct ChunkedBodyWriter(C); impl ChunkedBodyWriter where C: Write, { pub fn new(conn: C) -> Self { - Self(conn, 0) + Self(conn) } - pub async fn write_empty_chunk(&mut self) -> Result<(), C::Error> { + pub async fn terminate(&mut self) -> Result<(), C::Error> { self.0.write_all(b"0\r\n\r\n").await } } @@ -377,21 +395,16 @@ where type Error = embedded_io::ErrorKind; } -fn to_errorkind(e: E) -> embedded_io::ErrorKind { - e.kind() -} - impl Write for ChunkedBodyWriter where C: Write, { async fn write(&mut self, buf: &[u8]) -> Result { - self.write_all(buf).await.map_err(to_errorkind)?; + self.write_all(buf).await.map_err(|e| e.kind())?; Ok(buf.len()) } async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> { - // Write chunk header let len = buf.len(); // Do not write an empty chunk as that will terminate the body @@ -400,19 +413,19 @@ where return Ok(()); } - let mut hex = [0; 2 * size_of::()]; - hex::encode_to_slice(len.to_be_bytes(), &mut hex).unwrap(); - let leading_zeros = hex.iter().position(|x| *x != b'0').unwrap_or_default(); - let (_, hex) = hex.split_at(leading_zeros); - self.0.write_all(hex).await.map_err(to_errorkind)?; - self.0.write_all(b"\r\n").await.map_err(to_errorkind)?; + // Write chunk header + let mut header_buf = [0; 2 * size_of::() + 2]; + let header_len = write_chunked_header(&mut header_buf, len); + self.0 + .write_all(&header_buf[..header_len]) + .await + .map_err(|e| e.kind())?; // Write chunk - self.0.write_all(buf).await.map_err(to_errorkind)?; - self.1 += len; + self.0.write_all(buf).await.map_err(|e| e.kind())?; - // Write newline - self.0.write_all(b"\r\n").await.map_err(to_errorkind)?; + // Write newline footer + self.0.write_all(b"\r\n").await.map_err(|e| e.kind())?; Ok(()) } @@ -421,10 +434,133 @@ where } } +pub struct BufferedChunkedBodyWriter<'a, C: Write> { + conn: C, + buf: &'a mut [u8], + header_pos: usize, + pos: usize, + max_header_size: usize, + max_footer_size: usize, +} + +impl<'a, C> BufferedChunkedBodyWriter<'a, C> +where + C: Write, +{ + pub fn new_with_data(conn: C, buf: &'a mut [u8], written: usize) -> Self { + let max_hex_chars = hex_chars(buf.len()); + let max_header_size = max_hex_chars as usize + 2; + let max_footer_size = 2; + assert!(buf.len() > max_header_size + max_footer_size); // There must be space for the chunk header and footer + Self { + conn, + buf, + header_pos: written, + pos: written + max_header_size, + max_header_size, + max_footer_size, + } + } + + pub async fn terminate(&mut self) -> Result<(), C::Error> { + if self.pos > self.header_pos + self.max_header_size { + self.finish_current_chunk(); + } + const EMPTY: &[u8; 5] = b"0\r\n\r\n"; + if self.header_pos + EMPTY.len() > self.buf.len() { + self.emit_finished_chunk().await?; + } + + self.buf[self.header_pos..self.header_pos + EMPTY.len()].copy_from_slice(EMPTY); + self.header_pos += EMPTY.len(); + self.pos = self.header_pos + self.max_header_size; + self.emit_finished_chunk().await + } + + fn append_current_chunk(&mut self, buf: &[u8]) -> usize { + let buffered = usize::min(buf.len(), self.buf.len() - self.max_footer_size - self.pos); + if buffered > 0 { + self.buf[self.pos..self.pos + buffered].copy_from_slice(&buf[..buffered]); + self.pos += buffered; + } + buffered + } + + fn finish_current_chunk(&mut self) { + // Write the header in the allocated position position + let chunk_len = self.pos - self.header_pos - self.max_header_size; + let header_buf = &mut self.buf[self.header_pos..self.header_pos + self.max_header_size]; + let header_len = write_chunked_header(header_buf, chunk_len); + + // Move the payload if the header length was not as large as it could possibly be + let spacing = self.max_header_size - header_len; + if spacing > 0 { + self.buf.copy_within( + self.header_pos + self.max_header_size..self.pos, + self.header_pos + header_len, + ); + self.pos -= spacing + } + + // Write newline footer after chunk payload + self.buf[self.pos..self.pos + 2].copy_from_slice(b"\r\n"); + self.pos += 2; + + self.header_pos = self.pos; + self.pos = self.header_pos + self.max_header_size; + } + + async fn emit_finished_chunk(&mut self) -> Result<(), C::Error> { + self.conn.write_all(&self.buf[..self.header_pos]).await?; + self.header_pos = 0; + self.pos = self.max_header_size; + Ok(()) + } +} + +impl ErrorType for BufferedChunkedBodyWriter<'_, C> +where + C: Write, +{ + type Error = embedded_io::ErrorKind; +} + +impl Write for BufferedChunkedBodyWriter<'_, C> +where + C: Write, +{ + async fn write(&mut self, buf: &[u8]) -> Result { + let written = self.append_current_chunk(buf); + if written < buf.len() { + self.finish_current_chunk(); + self.emit_finished_chunk().await.map_err(|e| e.kind())?; + } + Ok(written) + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + if self.header_pos > 0 { + self.finish_current_chunk(); + self.emit_finished_chunk().await.map_err(|e| e.kind())?; + } + self.conn.flush().await.map_err(|e| e.kind()) + } +} + #[cfg(test)] mod tests { use super::*; + #[test] + fn hex_chars_values() { + assert_eq!(1, hex_chars(0)); + assert_eq!(1, hex_chars(1)); + assert_eq!(1, hex_chars(0xF)); + assert_eq!(2, hex_chars(0x10)); + assert_eq!(2, hex_chars(0xFF)); + assert_eq!(3, hex_chars(0x100)); + } + #[tokio::test] async fn basic_auth() { let mut buffer: Vec = Vec::new();