diff --git a/crates/test-programs/src/bin/p3_http_outbound_request_content_length.rs b/crates/test-programs/src/bin/p3_http_outbound_request_content_length.rs index cf01e1cde496..2bae98ec7656 100644 --- a/crates/test-programs/src/bin/p3_http_outbound_request_content_length.rs +++ b/crates/test-programs/src/bin/p3_http_outbound_request_content_length.rs @@ -1,4 +1,3 @@ -use anyhow::Context as _; use futures::join; use test_programs::p3::wasi::http::handler; use test_programs::p3::wasi::http::types::{ErrorCode, Headers, Method, Request, Scheme, Trailers}; @@ -47,13 +46,9 @@ impl test_programs::p3::exports::wasi::cli::run::Guest for Component { println!("writing enough"); { let (request, mut contents_tx, trailers_tx, transmit) = make_request(); - let (handle, (), ()) = join!( - async { - let res = handler::handle(request) - .await - .context("failed to send request")?; - anyhow::Ok(res) - }, + let (handle, transmit, ()) = join!( + async { handler::handle(request).await }, + async { transmit.await }, async { let remaining = contents_tx.write_all(b"long enough".to_vec()).await; assert!( @@ -64,12 +59,9 @@ impl test_programs::p3::exports::wasi::cli::run::Guest for Component { trailers_tx.write(Ok(None)).await.unwrap(); drop(contents_tx); }, - async { - transmit.await.unwrap(); - }, ); - let res = handle.unwrap(); - drop(res); + let _res = handle.expect("failed to send request"); + transmit.expect("failed to transmit request"); } println!("writing too little"); @@ -89,8 +81,11 @@ impl test_programs::p3::exports::wasi::cli::run::Guest for Component { trailers_tx.write(Ok(None)).await.unwrap(); }, ); - let res = handle.unwrap(); - drop(res); + let err = handle.expect_err("should have failed to send request"); + assert!( + matches!(err, ErrorCode::HttpProtocolError), + "unexpected error: {err:#?}" + ); let err = transmit.expect_err("request transmission should have failed"); assert!( matches!(err, ErrorCode::HttpRequestBodySize(Some(3))), @@ -106,18 +101,15 @@ impl test_programs::p3::exports::wasi::cli::run::Guest for Component { async { transmit.await }, async { let remaining = contents_tx.write_all(b"more than 11 bytes".to_vec()).await; - assert!( - remaining.is_empty(), - "{}", - String::from_utf8_lossy(&remaining) - ); + assert_eq!(String::from_utf8_lossy(&remaining), "more than 11 bytes",); drop(contents_tx); _ = trailers_tx.write(Ok(None)).await; }, ); - let res = handle.unwrap(); - drop(res); + // The the error returned by `handle` in this case is non-deterministic, + // so just assert that it fails + let _err = handle.expect_err("should have failed to send request"); let err = transmit.expect_err("request transmission should have failed"); assert!( matches!(err, ErrorCode::HttpRequestBodySize(Some(18))), diff --git a/crates/test-programs/src/bin/p3_http_outbound_request_response_build.rs b/crates/test-programs/src/bin/p3_http_outbound_request_response_build.rs index d616577efaac..0213986ad4d6 100644 --- a/crates/test-programs/src/bin/p3_http_outbound_request_response_build.rs +++ b/crates/test-programs/src/bin/p3_http_outbound_request_response_build.rs @@ -25,11 +25,9 @@ impl test_programs::p3::exports::wasi::cli::run::Guest for Component { request .set_authority(Some("www.example.com")) .expect("setting authority"); - let (remaining, ()) = - futures::join!(contents_tx.write_all(b"request-body".to_vec()), async { - drop(request); - },); - assert!(!remaining.is_empty()); + drop(request); + let remaining = contents_tx.write_all(b"request-body".to_vec()).await; + assert_eq!(String::from_utf8_lossy(&remaining), "request-body"); } { let headers = Headers::from_list(&[( diff --git a/crates/wasi-http/src/lib.rs b/crates/wasi-http/src/lib.rs index bf0088afb80d..38a3fd18dec1 100644 --- a/crates/wasi-http/src/lib.rs +++ b/crates/wasi-http/src/lib.rs @@ -230,7 +230,6 @@ pub mod types; pub mod bindings; #[cfg(feature = "p3")] -#[expect(missing_docs, reason = "work in progress")] // TODO: add docs pub mod p3; pub use crate::error::{ diff --git a/crates/wasi-http/src/p3/bindings.rs b/crates/wasi-http/src/p3/bindings.rs index 01eaa0aedacc..b7fec724c11c 100644 --- a/crates/wasi-http/src/p3/bindings.rs +++ b/crates/wasi-http/src/p3/bindings.rs @@ -7,6 +7,8 @@ mod generated { world: "wasi:http/proxy", imports: { "wasi:http/handler/[async]handle": async | store | trappable | tracing, + "wasi:http/types/[drop]request": store | trappable | tracing, + "wasi:http/types/[drop]response": store | trappable | tracing, "wasi:http/types/[method]request.consume-body": async | store | trappable | tracing, "wasi:http/types/[method]response.consume-body": async | store | trappable | tracing, "wasi:http/types/[static]request.new": async | store | trappable | tracing, @@ -22,14 +24,13 @@ mod generated { }, trappable_error_type: { "wasi:http/types/error-code" => crate::p3::HttpError, + "wasi:http/types/header-error" => crate::p3::HeaderError, + "wasi:http/types/request-options-error" => crate::p3::RequestOptionsError, }, }); mod with { - /// The concrete type behind a `wasi:http/types/fields` resource. pub type Fields = crate::p3::MaybeMutable; - - /// The concrete type behind a `wasi:http/types/request-options` resource. pub type RequestOptions = crate::p3::MaybeMutable; } } diff --git a/crates/wasi-http/src/p3/body.rs b/crates/wasi-http/src/p3/body.rs index 7ea96eaffda4..6af49de504d0 100644 --- a/crates/wasi-http/src/p3/body.rs +++ b/crates/wasi-http/src/p3/body.rs @@ -1,18 +1,23 @@ -use crate::p3::WasiHttpCtxView; -use crate::p3::bindings::http::types::{ErrorCode, Trailers}; +use crate::p3::bindings::http::types::{ErrorCode, Fields, Trailers}; +use crate::p3::{WasiHttp, WasiHttpCtxView}; use anyhow::Context as _; use bytes::Bytes; +use core::num::NonZeroUsize; use core::pin::Pin; use core::task::{Context, Poll, ready}; use http::HeaderMap; +use http_body::Body as _; use http_body_util::combinators::BoxBody; +use std::io::Cursor; use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; use tokio_util::sync::PollSender; use wasmtime::component::{ - FutureConsumer, FutureReader, Resource, Source, StreamConsumer, StreamReader, StreamResult, + Access, Destination, FutureConsumer, FutureReader, Resource, Source, StreamConsumer, + StreamProducer, StreamReader, StreamResult, }; use wasmtime::{AsContextMut, StoreContextMut}; +use wasmtime_wasi::p3::{FutureOneshotProducer, StreamEmptyProducer}; /// The concrete type behind a `wasi:http/types/body` resource. pub(crate) enum Body { @@ -27,6 +32,7 @@ pub(crate) enum Body { }, /// Body constructed by the host. Host { + /// The [`http_body::Body`] body: BoxBody, /// Channel, on which transmission result will be written result_tx: oneshot::Sender> + Send>>, @@ -35,8 +41,148 @@ pub(crate) enum Body { Consumed, } -pub(crate) struct GuestBodyConsumer { - pub(crate) tx: PollSender, +impl Body { + pub(crate) fn consume( + self, + mut store: Access<'_, T, WasiHttp>, + getter: fn(&mut T) -> WasiHttpCtxView<'_>, + ) -> Result< + ( + StreamReader, + FutureReader>, ErrorCode>>, + ), + (), + > { + match self { + Body::Guest { + contents_rx: Some(contents_rx), + trailers_rx, + result_tx, + } => { + // TODO: Use a result specified by the caller + // https://github.com/WebAssembly/wasi-http/issues/176 + _ = result_tx.send(Box::new(async { Ok(()) })); + Ok((contents_rx, trailers_rx)) + } + Body::Guest { + contents_rx: None, + trailers_rx, + result_tx, + } => { + let instance = store.instance(); + // TODO: Use a result specified by the caller + // https://github.com/WebAssembly/wasi-http/issues/176 + _ = result_tx.send(Box::new(async { Ok(()) })); + Ok(( + StreamReader::new(instance, &mut store, StreamEmptyProducer::default()), + trailers_rx, + )) + } + Body::Host { body, result_tx } => { + let instance = store.instance(); + // TODO: Use a result specified by the caller + // https://github.com/WebAssembly/wasi-http/issues/176 + _ = result_tx.send(Box::new(async { Ok(()) })); + let (trailers_tx, trailers_rx) = oneshot::channel(); + Ok(( + StreamReader::new( + instance, + &mut store, + HostBodyStreamProducer { + body, + trailers: Some(trailers_tx), + getter, + }, + ), + FutureReader::new( + instance, + &mut store, + FutureOneshotProducer::from(trailers_rx), + ), + )) + } + Body::Consumed => Err(()), + } + } + + pub(crate) fn drop(self, mut store: impl AsContextMut) { + if let Body::Guest { + contents_rx, + mut trailers_rx, + .. + } = self + { + if let Some(mut contents_rx) = contents_rx { + contents_rx.close(&mut store); + } + trailers_rx.close(store); + } + } +} + +pub(crate) enum GuestBodyKind { + Request, + Response, +} + +/// Represents `Content-Length` limit and state +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +struct ContentLength { + /// Limit of bytes to be sent + limit: u64, + /// Number of bytes sent + sent: u64, +} + +impl ContentLength { + /// Constructs new [ContentLength] + fn new(limit: u64) -> Self { + Self { limit, sent: 0 } + } +} + +struct GuestBodyConsumer { + contents_tx: PollSender>, + result_tx: Option>>, + content_length: Option, + kind: GuestBodyKind, + // `true` when the other side of `contents_tx` was unexpectedly closed + closed: bool, +} + +impl GuestBodyConsumer { + fn body_size_error(&self, n: Option) -> ErrorCode { + match self.kind { + GuestBodyKind::Request => ErrorCode::HttpRequestBodySize(n), + GuestBodyKind::Response => ErrorCode::HttpResponseBodySize(n), + } + } + + // Sends the corresponding error constructed by [Self::body_size_error] on both + // error channels. + // [`PollSender::poll_reserve`] on `contents_tx` must have succeeed prior to this being called. + fn send_body_size_error(&mut self, n: Option) { + if let Some(result_tx) = self.result_tx.take() { + _ = result_tx.send(Err(self.body_size_error(n))); + _ = self.contents_tx.send_item(Err(self.body_size_error(n))); + } + } +} + +impl Drop for GuestBodyConsumer { + fn drop(&mut self) { + if let Some(result_tx) = self.result_tx.take() { + if let Some(ContentLength { limit, sent }) = self.content_length { + if !self.closed && limit != sent { + _ = result_tx.send(Err(self.body_size_error(Some(sent)))); + self.contents_tx.abort_send(); + if let Some(tx) = self.contents_tx.get_ref() { + _ = tx.try_send(Err(self.body_size_error(Some(sent)))) + } + } + } + } + } } impl StreamConsumer for GuestBodyConsumer { @@ -49,20 +195,40 @@ impl StreamConsumer for GuestBodyConsumer { src: Source, finish: bool, ) -> Poll> { - match self.tx.poll_reserve(cx) { + debug_assert!(!self.closed); + match self.contents_tx.poll_reserve(cx) { Poll::Ready(Ok(())) => { let mut src = src.as_direct(store); - let buf = Bytes::copy_from_slice(src.remaining()); + let buf = src.remaining(); + if let Some(ContentLength { limit, sent }) = self.content_length.as_mut() { + let Some(n) = buf.len().try_into().ok().and_then(|n| sent.checked_add(n)) + else { + self.send_body_size_error(None); + return Poll::Ready(Ok(StreamResult::Dropped)); + }; + if n > *limit { + self.send_body_size_error(Some(n)); + return Poll::Ready(Ok(StreamResult::Dropped)); + } + *sent = n; + } + let buf = Bytes::copy_from_slice(buf); let n = buf.len(); - match self.tx.send_item(buf) { + match self.contents_tx.send_item(Ok(buf)) { Ok(()) => { src.mark_read(n); Poll::Ready(Ok(StreamResult::Completed)) } - Err(..) => Poll::Ready(Ok(StreamResult::Dropped)), + Err(..) => { + self.closed = true; + Poll::Ready(Ok(StreamResult::Dropped)) + } } } - Poll::Ready(Err(..)) => Poll::Ready(Ok(StreamResult::Dropped)), + Poll::Ready(Err(..)) => { + self.closed = true; + Poll::Ready(Ok(StreamResult::Dropped)) + } Poll::Pending if finish => Poll::Ready(Ok(StreamResult::Cancelled)), Poll::Pending => Poll::Pending, } @@ -70,17 +236,20 @@ impl StreamConsumer for GuestBodyConsumer { } pub(crate) struct GuestBody { - pub(crate) contents_rx: Option>, - pub(crate) trailers_rx: - Option>, ErrorCode>>>, + contents_rx: Option>>, + trailers_rx: Option>, ErrorCode>>>, + content_length: Option, } impl GuestBody { - pub fn new( + pub(crate) fn new( mut store: impl AsContextMut, contents_rx: Option>, trailers_rx: FutureReader>, ErrorCode>>, - getter: for<'a> fn(&'a mut T) -> WasiHttpCtxView<'a>, + result_tx: oneshot::Sender>, + content_length: Option, + kind: GuestBodyKind, + getter: fn(&mut T) -> WasiHttpCtxView<'_>, ) -> Self { let (trailers_http_tx, trailers_http_rx) = oneshot::channel(); trailers_rx.pipe( @@ -95,7 +264,11 @@ impl GuestBody { rx.pipe( store, GuestBodyConsumer { - tx: PollSender::new(http_tx), + contents_tx: PollSender::new(http_tx), + result_tx: Some(result_tx), + content_length: content_length.map(ContentLength::new), + kind, + closed: false, }, ); http_rx @@ -103,6 +276,7 @@ impl GuestBody { Self { trailers_rx: Some(trailers_http_rx), contents_rx, + content_length, } } } @@ -116,8 +290,18 @@ impl http_body::Body for GuestBody { cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { if let Some(contents_rx) = self.contents_rx.as_mut() { - while let Some(buf) = ready!(contents_rx.poll_recv(cx)) { - return Poll::Ready(Some(Ok(http_body::Frame::data(buf)))); + while let Some(res) = ready!(contents_rx.poll_recv(cx)) { + match res { + Ok(buf) => { + if let Some(n) = self.content_length.as_mut() { + *n = n.saturating_sub(buf.len().try_into().unwrap_or(u64::MAX)); + } + return Poll::Ready(Some(Ok(http_body::Frame::data(buf)))); + } + Err(err) => { + return Poll::Ready(Some(Err(err))); + } + } } self.contents_rx = None; } @@ -140,7 +324,10 @@ impl http_body::Body for GuestBody { fn is_end_stream(&self) -> bool { if let Some(contents_rx) = self.contents_rx.as_ref() { - if !contents_rx.is_empty() || !contents_rx.is_closed() { + if !contents_rx.is_empty() + || !contents_rx.is_closed() + || self.content_length.is_some_and(|n| n > 0) + { return false; } } @@ -153,8 +340,11 @@ impl http_body::Body for GuestBody { } fn size_hint(&self) -> http_body::SizeHint { - // TODO: use content-length - http_body::SizeHint::default() + if let Some(n) = self.content_length { + http_body::SizeHint::with_exact(n) + } else { + http_body::SizeHint::default() + } } } @@ -184,7 +374,7 @@ impl http_body::Body for ConsumedBody { pub(crate) struct GuestTrailerConsumer { pub(crate) tx: Option>, ErrorCode>>>, - pub(crate) getter: for<'a> fn(&'a mut T) -> WasiHttpCtxView<'a>, + pub(crate) getter: fn(&mut T) -> WasiHttpCtxView<'_>, } impl FutureConsumer for GuestTrailerConsumer @@ -194,7 +384,7 @@ where type Item = Result>, ErrorCode>; fn poll_consume( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, _: &mut Context<'_>, mut store: StoreContextMut, mut source: Source<'_, Self::Item>, @@ -202,61 +392,107 @@ where ) -> Poll> { let value = &mut None; source.read(store.as_context_mut(), value)?; - let res = value.take().unwrap(); - let me = self.get_mut(); - match res { + let res = match value.take().unwrap() { Ok(Some(trailers)) => { - let WasiHttpCtxView { table, .. } = (me.getter)(store.data_mut()); + let WasiHttpCtxView { table, .. } = (self.getter)(store.data_mut()); let trailers = table .delete(trailers) .context("failed to delete trailers")?; - _ = me.tx.take().unwrap().send(Ok(Some(Arc::from(trailers)))); - } - Ok(None) => { - _ = me.tx.take().unwrap().send(Ok(None)); - } - Err(err) => { - _ = me.tx.take().unwrap().send(Err(err)); + Ok(Some(Arc::from(trailers))) } - } + Ok(None) => Ok(None), + Err(err) => Err(err), + }; + _ = self.tx.take().unwrap().send(res); Poll::Ready(Ok(())) } } -pub(crate) struct IncomingResponseBody { - pub incoming: hyper::body::Incoming, - pub timeout: tokio::time::Interval, +struct HostBodyStreamProducer { + body: BoxBody, + trailers: Option>, ErrorCode>>>, + getter: fn(&mut T) -> WasiHttpCtxView<'_>, } -impl http_body::Body for IncomingResponseBody { - type Data = ::Data; - type Error = ErrorCode; +impl Drop for HostBodyStreamProducer { + fn drop(&mut self) { + self.close(Ok(None)) + } +} - fn poll_frame( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - match Pin::new(&mut self.as_mut().incoming).poll_frame(cx) { - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Err(err))) => { - Poll::Ready(Some(Err(ErrorCode::from_hyper_response_error(err)))) - } - Poll::Ready(Some(Ok(frame))) => { - self.timeout.reset(); - Poll::Ready(Some(Ok(frame))) - } - Poll::Pending => { - ready!(self.timeout.poll_tick(cx)); - Poll::Ready(Some(Err(ErrorCode::ConnectionReadTimeout))) - } +impl HostBodyStreamProducer { + fn close(&mut self, res: Result>, ErrorCode>) { + if let Some(tx) = self.trailers.take() { + _ = tx.send(res); } } +} - fn is_end_stream(&self) -> bool { - self.incoming.is_end_stream() - } +impl StreamProducer for HostBodyStreamProducer +where + D: 'static, +{ + type Item = u8; + type Buffer = Cursor; - fn size_hint(&self) -> http_body::SizeHint { - self.incoming.size_hint() + fn poll_produce<'a>( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut store: StoreContextMut<'a, D>, + mut dst: Destination<'a, Self::Item, Self::Buffer>, + finish: bool, + ) -> Poll> { + let res = 'result: { + let cap = match dst.remaining(&mut store).map(NonZeroUsize::new) { + Some(Some(cap)) => Some(cap), + Some(None) => { + if self.body.is_end_stream() { + break 'result Ok(None); + } else { + return Poll::Ready(Ok(StreamResult::Completed)); + } + } + None => None, + }; + match Pin::new(&mut self.body).poll_frame(cx) { + Poll::Ready(Some(Ok(frame))) => { + match frame.into_data().map_err(http_body::Frame::into_trailers) { + Ok(mut frame) => { + if let Some(cap) = cap { + let n = frame.len(); + let cap = cap.into(); + if n > cap { + dst.set_buffer(Cursor::new(frame.split_off(cap))); + let mut dst = dst.as_direct(store, cap); + dst.remaining().copy_from_slice(&frame); + dst.mark_written(cap); + } else { + let mut dst = dst.as_direct(store, n); + dst.remaining()[..n].copy_from_slice(&frame); + dst.mark_written(n); + } + } else { + dst.set_buffer(Cursor::new(frame)); + } + return Poll::Ready(Ok(StreamResult::Completed)); + } + Err(Ok(trailers)) => { + let trailers = (self.getter)(store.data_mut()) + .table + .push(Fields::new_mutable(trailers)) + .context("failed to push trailers to table")?; + break 'result Ok(Some(trailers)); + } + Err(Err(..)) => break 'result Err(ErrorCode::HttpProtocolError), + } + } + Poll::Ready(Some(Err(err))) => break 'result Err(err), + Poll::Ready(None) => break 'result Ok(None), + Poll::Pending if finish => return Poll::Ready(Ok(StreamResult::Cancelled)), + Poll::Pending => return Poll::Pending, + } + }; + self.close(res); + Poll::Ready(Ok(StreamResult::Dropped)) } } diff --git a/crates/wasi-http/src/p3/conv.rs b/crates/wasi-http/src/p3/conv.rs index 64c5a385370d..ef4587fb7848 100644 --- a/crates/wasi-http/src/p3/conv.rs +++ b/crates/wasi-http/src/p3/conv.rs @@ -11,7 +11,7 @@ impl From for ErrorCode { impl ErrorCode { /// Translate a [`hyper::Error`] to a wasi-http [ErrorCode] in the context of a request. - pub fn from_hyper_request_error(err: hyper::Error) -> Self { + pub(crate) fn from_hyper_request_error(err: hyper::Error) -> Self { // If there's a source, we might be able to extract a wasi-http error from it. if let Some(cause) = err.source() { if let Some(err) = cause.downcast_ref::() { @@ -25,7 +25,7 @@ impl ErrorCode { } /// Translate a [`hyper::Error`] to a wasi-http [ErrorCode] in the context of a response. - pub fn from_hyper_response_error(err: hyper::Error) -> Self { + pub(crate) fn from_hyper_response_error(err: hyper::Error) -> Self { if err.is_timeout() { return ErrorCode::HttpResponseTimeout; } diff --git a/crates/wasi-http/src/p3/host/handler.rs b/crates/wasi-http/src/p3/host/handler.rs index 0c73bf52ce2e..a3d4bc6a6e48 100644 --- a/crates/wasi-http/src/p3/host/handler.rs +++ b/crates/wasi-http/src/p3/host/handler.rs @@ -1,15 +1,30 @@ use crate::p3::bindings::http::handler::{Host, HostWithStore}; use crate::p3::bindings::http::types::{ErrorCode, Request, Response}; -use crate::p3::body::{Body, ConsumedBody, GuestBody}; -use crate::p3::host::{delete_request, push_response}; -use crate::p3::{HttpError, HttpResult, WasiHttp, WasiHttpCtxView}; +use crate::p3::body::{Body, ConsumedBody, GuestBody, GuestBodyKind}; +use crate::p3::{HttpError, HttpResult, WasiHttp, WasiHttpCtxView, get_content_length}; +use anyhow::Context as _; +use core::pin::Pin; use http::header::HOST; use http::{HeaderValue, Uri}; use http_body_util::BodyExt as _; use std::sync::Arc; use tokio::sync::oneshot; use tracing::debug; -use wasmtime::component::{Accessor, Resource}; +use wasmtime::component::{Accessor, AccessorTask, Resource}; + +struct SendRequestTask { + io: Pin> + Send>>, + result_tx: oneshot::Sender>, +} + +impl AccessorTask> for SendRequestTask { + async fn run(self, _: &Accessor) -> wasmtime::Result<()> { + let res = self.io.await; + debug!(?res, "`send_request` I/O future finished"); + _ = self.result_tx.send(res); + Ok(()) + } +} impl HostWithStore for WasiHttp { async fn handle( @@ -17,9 +32,10 @@ impl HostWithStore for WasiHttp { req: Resource, ) -> HttpResult> { let getter = store.getter(); + let (io_result_tx, io_result_rx) = oneshot::channel(); let (res_result_tx, res_result_rx) = oneshot::channel(); - let (fut, req_result_tx) = store.with(|mut store| { - let WasiHttpCtxView { ctx, table } = store.get(); + let fut = store.with(|mut store| { + let WasiHttpCtxView { table, .. } = store.get(); let Request { method, scheme, @@ -28,8 +44,47 @@ impl HostWithStore for WasiHttp { headers, options, body, - } = delete_request(table, req).map_err(HttpError::trap)?; + } = table + .delete(req) + .context("failed to delete request from table") + .map_err(HttpError::trap)?; let mut headers = Arc::unwrap_or_clone(headers); + let body = match body { + Body::Guest { + contents_rx, + trailers_rx, + result_tx, + } => { + let (http_result_tx, http_result_rx) = oneshot::channel(); + let content_length = get_content_length(&headers) + .map_err(|err| ErrorCode::InternalError(Some(format!("{err:#}"))))?; + _ = result_tx.send(Box::new(async move { + if let Ok(Err(err)) = http_result_rx.await { + return Err(err); + }; + io_result_rx.await.unwrap_or(Ok(())) + })); + GuestBody::new( + &mut store, + contents_rx, + trailers_rx, + http_result_tx, + content_length, + GuestBodyKind::Request, + getter, + ) + .boxed() + } + Body::Host { body, result_tx } => { + _ = result_tx.send(Box::new( + async move { io_result_rx.await.unwrap_or(Ok(())) }, + )); + body + } + Body::Consumed => ConsumedBody.boxed(), + }; + + let WasiHttpCtxView { ctx, .. } = store.get(); if ctx.set_host_header() { let host = if let Some(authority) = authority.as_ref() { HeaderValue::try_from(authority.as_str()) @@ -39,7 +94,6 @@ impl HostWithStore for WasiHttp { }; headers.insert(HOST, host); } - let scheme = match scheme { None => ctx.default_scheme().ok_or(ErrorCode::HttpProtocolError)?, Some(scheme) if ctx.is_supported_scheme(&scheme) => scheme, @@ -56,48 +110,29 @@ impl HostWithStore for WasiHttp { debug!(?err, "failed to build request URI"); ErrorCode::HttpRequestUriInvalid })?; - let mut req = http::Request::builder(); *req.headers_mut().unwrap() = headers; - let (body, result_tx) = match body { - Body::Guest { - contents_rx, - trailers_rx, - result_tx, - } => ( - GuestBody::new(&mut store, contents_rx, trailers_rx, getter).boxed(), - Some(result_tx), - ), - Body::Host { body, result_tx } => (body, Some(result_tx)), - Body::Consumed => (ConsumedBody.boxed(), None), - }; let req = req .method(method) .uri(uri) .body(body) .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?; - HttpResult::Ok(( - store.get().ctx.send_request( - req, - options.as_deref().copied(), - Box::new(async { - let Ok(fut) = res_result_rx.await else { - return Ok(()); - }; - Box::into_pin(fut).await - }), - ), - result_tx, + HttpResult::Ok(store.get().ctx.send_request( + req, + options.as_deref().copied(), + Box::new(async { + let Ok(fut) = res_result_rx.await else { + return Ok(()); + }; + Box::into_pin(fut).await + }), )) })?; let (res, io) = Box::into_pin(fut).await?; - if let Some(req_result_tx) = req_result_tx { - if let Err(io) = req_result_tx.send(io) { - Box::into_pin(io).await?; - } - } else { - Box::into_pin(io).await?; - } + store.spawn(SendRequestTask { + io: Box::into_pin(io), + result_tx: io_result_tx, + }); let ( http::response::Parts { status, headers, .. @@ -112,7 +147,14 @@ impl HostWithStore for WasiHttp { result_tx: res_result_tx, }, }; - store.with(|mut store| push_response(store.get().table, res).map_err(HttpError::trap)) + store.with(|mut store| { + store + .get() + .table + .push(res) + .context("failed to push response to table") + .map_err(HttpError::trap) + }) } } diff --git a/crates/wasi-http/src/p3/host/mod.rs b/crates/wasi-http/src/p3/host/mod.rs index 702650a72865..9d6627020f26 100644 --- a/crates/wasi-http/src/p3/host/mod.rs +++ b/crates/wasi-http/src/p3/host/mod.rs @@ -1,89 +1,2 @@ -use crate::p3::bindings::http::types::{Fields, Request, Response}; -use anyhow::Context as _; -use wasmtime::component::{Resource, ResourceTable}; - mod handler; mod types; - -fn get_fields<'a>( - table: &'a ResourceTable, - fields: &Resource, -) -> wasmtime::Result<&'a Fields> { - table - .get(&fields) - .context("failed to get fields from table") -} - -fn get_fields_mut<'a>( - table: &'a mut ResourceTable, - fields: &Resource, -) -> wasmtime::Result<&'a mut Fields> { - table - .get_mut(&fields) - .context("failed to get fields from table") -} - -fn push_fields(table: &mut ResourceTable, fields: Fields) -> wasmtime::Result> { - table.push(fields).context("failed to push fields to table") -} - -fn delete_fields(table: &mut ResourceTable, fields: Resource) -> wasmtime::Result { - table - .delete(fields) - .context("failed to delete fields from table") -} - -fn get_request<'a>( - table: &'a ResourceTable, - req: &Resource, -) -> wasmtime::Result<&'a Request> { - table.get(req).context("failed to get request from table") -} - -fn get_request_mut<'a>( - table: &'a mut ResourceTable, - req: &Resource, -) -> wasmtime::Result<&'a mut Request> { - table - .get_mut(req) - .context("failed to get request from table") -} - -fn push_request(table: &mut ResourceTable, req: Request) -> wasmtime::Result> { - table.push(req).context("failed to push request to table") -} - -fn delete_request(table: &mut ResourceTable, req: Resource) -> wasmtime::Result { - table - .delete(req) - .context("failed to delete request from table") -} - -fn get_response<'a>( - table: &'a ResourceTable, - res: &Resource, -) -> wasmtime::Result<&'a Response> { - table.get(res).context("failed to get response from table") -} - -fn get_response_mut<'a>( - table: &'a mut ResourceTable, - res: &Resource, -) -> wasmtime::Result<&'a mut Response> { - table - .get_mut(res) - .context("failed to get response from table") -} - -fn push_response(table: &mut ResourceTable, res: Response) -> wasmtime::Result> { - table.push(res).context("failed to push response to table") -} - -fn delete_response( - table: &mut ResourceTable, - res: Resource, -) -> wasmtime::Result { - table - .delete(res) - .context("failed to delete response from table") -} diff --git a/crates/wasi-http/src/p3/host/types.rs b/crates/wasi-http/src/p3/host/types.rs index 5e76d3633b53..fab634a75c82 100644 --- a/crates/wasi-http/src/p3/host/types.rs +++ b/crates/wasi-http/src/p3/host/types.rs @@ -1,7 +1,3 @@ -use super::{ - delete_fields, delete_request, delete_response, get_fields, get_fields_mut, get_request, - get_request_mut, get_response, get_response_mut, push_fields, push_request, push_response, -}; use crate::p3::bindings::clocks::monotonic_clock::Duration; use crate::p3::bindings::http::types::{ ErrorCode, FieldName, FieldValue, Fields, HeaderError, Headers, Host, HostFields, HostRequest, @@ -9,27 +5,79 @@ use crate::p3::bindings::http::types::{ RequestOptions, RequestOptionsError, Response, Scheme, StatusCode, Trailers, }; use crate::p3::body::Body; -use crate::p3::{HttpError, WasiHttp, WasiHttpCtxView}; +use crate::p3::{HeaderResult, HttpError, RequestOptionsResult, WasiHttp, WasiHttpCtxView}; use anyhow::Context as _; -use bytes::Bytes; use core::mem; -use core::num::NonZeroUsize; use core::pin::Pin; -use core::task::Context; -use core::task::Poll; +use core::task::{Context, Poll}; use http::header::CONTENT_LENGTH; -use http_body::Body as _; -use http_body_util::combinators::BoxBody; -use std::io::Cursor; use std::sync::Arc; use tokio::sync::oneshot; use wasmtime::StoreContextMut; use wasmtime::component::{ - Accessor, Destination, FutureProducer, FutureReader, Resource, StreamProducer, StreamReader, - StreamResult, + Access, Accessor, FutureProducer, FutureReader, Resource, ResourceTable, StreamReader, }; -use wasmtime_wasi::ResourceTable; -use wasmtime_wasi::p3::{FutureOneshotProducer, StreamEmptyProducer}; + +fn get_fields<'a>( + table: &'a ResourceTable, + fields: &Resource, +) -> wasmtime::Result<&'a Fields> { + table + .get(&fields) + .context("failed to get fields from table") +} + +fn get_fields_mut<'a>( + table: &'a mut ResourceTable, + fields: &Resource, +) -> HeaderResult<&'a mut Fields> { + table + .get_mut(&fields) + .context("failed to get fields from table") + .map_err(crate::p3::HeaderError::trap) +} + +fn push_fields(table: &mut ResourceTable, fields: Fields) -> wasmtime::Result> { + table.push(fields).context("failed to push fields to table") +} + +fn delete_fields(table: &mut ResourceTable, fields: Resource) -> wasmtime::Result { + table + .delete(fields) + .context("failed to delete fields from table") +} + +fn get_request<'a>( + table: &'a ResourceTable, + req: &Resource, +) -> wasmtime::Result<&'a Request> { + table.get(req).context("failed to get request from table") +} + +fn get_request_mut<'a>( + table: &'a mut ResourceTable, + req: &Resource, +) -> wasmtime::Result<&'a mut Request> { + table + .get_mut(req) + .context("failed to get request from table") +} + +fn get_response<'a>( + table: &'a ResourceTable, + res: &Resource, +) -> wasmtime::Result<&'a Response> { + table.get(res).context("failed to get response from table") +} + +fn get_response_mut<'a>( + table: &'a mut ResourceTable, + res: &Resource, +) -> wasmtime::Result<&'a mut Response> { + table + .get_mut(res) + .context("failed to get response from table") +} fn get_request_options<'a>( table: &'a ResourceTable, @@ -43,10 +91,11 @@ fn get_request_options<'a>( fn get_request_options_mut<'a>( table: &'a mut ResourceTable, opts: &Resource, -) -> wasmtime::Result<&'a mut RequestOptions> { +) -> RequestOptionsResult<&'a mut RequestOptions> { table .get_mut(opts) .context("failed to get request options from table") + .map_err(crate::p3::RequestOptionsError::trap) } fn push_request_options( @@ -112,95 +161,6 @@ impl FutureProducer for GuestBodyResultProducer { } } -struct HostBodyStreamProducer { - body: BoxBody, - trailers: Option>, ErrorCode>>>, - getter: for<'a> fn(&'a mut T) -> WasiHttpCtxView<'a>, -} - -impl Drop for HostBodyStreamProducer { - fn drop(&mut self) { - self.close(Ok(None)) - } -} - -impl HostBodyStreamProducer { - fn close(&mut self, res: Result>, ErrorCode>) { - if let Some(tx) = self.trailers.take() { - _ = tx.send(res); - } - } -} - -impl StreamProducer for HostBodyStreamProducer -where - D: 'static, -{ - type Item = u8; - type Buffer = Cursor; - - fn poll_produce<'a>( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut store: StoreContextMut<'a, D>, - mut dst: Destination<'a, Self::Item, Self::Buffer>, - finish: bool, - ) -> Poll> { - let res = 'result: { - let cap = match dst.remaining(&mut store).map(NonZeroUsize::new) { - Some(Some(cap)) => Some(cap), - Some(None) => { - if self.body.is_end_stream() { - break 'result Ok(None); - } else { - return Poll::Ready(Ok(StreamResult::Completed)); - } - } - None => None, - }; - match Pin::new(&mut self.body).poll_frame(cx) { - Poll::Ready(Some(Ok(frame))) => { - match frame.into_data().map_err(http_body::Frame::into_trailers) { - Ok(mut frame) => { - if let Some(cap) = cap { - let n = frame.len(); - let cap = cap.into(); - if n > cap { - dst.set_buffer(Cursor::new(frame.split_off(cap))); - let mut dst = dst.as_direct(store, cap); - dst.remaining().copy_from_slice(&frame); - dst.mark_written(cap); - } else { - let mut dst = dst.as_direct(store, n); - dst.remaining()[..n].copy_from_slice(&frame); - dst.mark_written(n); - } - } else { - dst.set_buffer(Cursor::new(frame)); - } - return Poll::Ready(Ok(StreamResult::Completed)); - } - Err(Ok(trailers)) => { - let trailers = push_fields( - (self.getter)(store.data_mut()).table, - Fields::new_mutable(trailers), - )?; - break 'result Ok(Some(trailers)); - } - Err(Err(..)) => break 'result Err(ErrorCode::HttpProtocolError), - } - } - Poll::Ready(Some(Err(err))) => break 'result Err(err), - Poll::Ready(None) => break 'result Ok(None), - Poll::Pending if finish => return Poll::Ready(Ok(StreamResult::Cancelled)), - Poll::Pending => return Poll::Pending, - } - }; - self.close(res); - Poll::Ready(Ok(StreamResult::Dropped)) - } -} - impl HostFields for WasiHttpCtxView<'_> { fn new(&mut self) -> wasmtime::Result> { push_fields(self.table, Fields::new_mutable_default()) @@ -209,24 +169,19 @@ impl HostFields for WasiHttpCtxView<'_> { fn from_list( &mut self, entries: Vec<(FieldName, FieldValue)>, - ) -> wasmtime::Result, HeaderError>> { + ) -> HeaderResult> { let mut fields = http::HeaderMap::default(); for (name, value) in entries { - let Ok(name) = name.parse() else { - return Ok(Err(HeaderError::InvalidSyntax)); - }; + let name = name.parse().or(Err(HeaderError::InvalidSyntax))?; if self.ctx.is_forbidden_header(&name) { - return Ok(Err(HeaderError::Forbidden)); - } - match parse_header_value(&name, value) { - Ok(value) => { - fields.append(name, value); - } - Err(err) => return Ok(Err(err)), + return Err(HeaderError::Forbidden.into()); } + let value = parse_header_value(&name, value)?; + fields.append(name, value); } - let fields = push_fields(self.table, Fields::new_mutable(fields))?; - Ok(Ok(fields)) + let fields = push_fields(self.table, Fields::new_mutable(fields)) + .map_err(crate::p3::HeaderError::trap)?; + Ok(fields) } fn get( @@ -252,73 +207,52 @@ impl HostFields for WasiHttpCtxView<'_> { fields: Resource, name: FieldName, value: Vec, - ) -> wasmtime::Result> { - let Ok(name) = name.parse() else { - return Ok(Err(HeaderError::InvalidSyntax)); - }; + ) -> HeaderResult<()> { + let name = name.parse().or(Err(HeaderError::InvalidSyntax))?; if self.ctx.is_forbidden_header(&name) { - return Ok(Err(HeaderError::Forbidden)); + return Err(HeaderError::Forbidden.into()); } let mut values = Vec::with_capacity(value.len()); for value in value { - match parse_header_value(&name, value) { - Ok(value) => { - values.push(value); - } - Err(err) => return Ok(Err(err)), - } + let value = parse_header_value(&name, value)?; + values.push(value); } let fields = get_fields_mut(self.table, &fields)?; - let Some(fields) = fields.get_mut() else { - return Ok(Err(HeaderError::Immutable)); - }; + let fields = fields.get_mut().ok_or(HeaderError::Immutable)?; fields.remove(&name); for value in values { fields.append(&name, value); } - Ok(Ok(())) + Ok(()) } - fn delete( - &mut self, - fields: Resource, - name: FieldName, - ) -> wasmtime::Result> { - let header = match http::HeaderName::from_bytes(name.as_bytes()) { - Ok(header) => header, - Err(_) => return Ok(Err(HeaderError::InvalidSyntax)), - }; - if self.ctx.is_forbidden_header(&header) { - return Ok(Err(HeaderError::Forbidden)); + fn delete(&mut self, fields: Resource, name: FieldName) -> HeaderResult<()> { + let name = name.parse().or(Err(HeaderError::InvalidSyntax))?; + if self.ctx.is_forbidden_header(&name) { + return Err(HeaderError::Forbidden.into()); } let fields = get_fields_mut(self.table, &fields)?; - let Some(fields) = fields.get_mut() else { - return Ok(Err(HeaderError::Immutable)); - }; + let fields = fields.get_mut().ok_or(HeaderError::Immutable)?; fields.remove(&name); - Ok(Ok(())) + Ok(()) } fn get_and_delete( &mut self, fields: Resource, name: FieldName, - ) -> wasmtime::Result, HeaderError>> { - let Ok(header) = http::header::HeaderName::from_bytes(name.as_bytes()) else { - return Ok(Err(HeaderError::InvalidSyntax)); - }; - if self.ctx.is_forbidden_header(&header) { - return Ok(Err(HeaderError::Forbidden)); + ) -> HeaderResult> { + let name = name.parse().or(Err(HeaderError::InvalidSyntax))?; + if self.ctx.is_forbidden_header(&name) { + return Err(HeaderError::Forbidden.into()); } let fields = get_fields_mut(self.table, &fields)?; - let Some(fields) = fields.get_mut() else { - return Ok(Err(HeaderError::Immutable)); - }; - let http::header::Entry::Occupied(entry) = fields.entry(header) else { - return Ok(Ok(vec![])); + let fields = fields.get_mut().ok_or(HeaderError::Immutable)?; + let http::header::Entry::Occupied(entry) = fields.entry(name) else { + return Ok(Vec::default()); }; let (.., values) = entry.remove_entry_mult(); - Ok(Ok(values.map(|header| header.as_bytes().into()).collect())) + Ok(values.map(|value| value.as_bytes().into()).collect()) } fn append( @@ -326,23 +260,16 @@ impl HostFields for WasiHttpCtxView<'_> { fields: Resource, name: FieldName, value: FieldValue, - ) -> wasmtime::Result> { - let Ok(name) = name.parse() else { - return Ok(Err(HeaderError::InvalidSyntax)); - }; + ) -> HeaderResult<()> { + let name = name.parse().or(Err(HeaderError::InvalidSyntax))?; if self.ctx.is_forbidden_header(&name) { - return Ok(Err(HeaderError::Forbidden)); + return Err(HeaderError::Forbidden.into()); } - let value = match parse_header_value(&name, value) { - Ok(value) => value, - Err(err) => return Ok(Err(err)), - }; + let value = parse_header_value(&name, value)?; let fields = get_fields_mut(self.table, &fields)?; - let Some(fields) = fields.get_mut() else { - return Ok(Err(HeaderError::Immutable)); - }; + let fields = fields.get_mut().ok_or(HeaderError::Immutable)?; fields.append(name, value); - Ok(Ok(())) + Ok(()) } fn copy_all( @@ -381,6 +308,7 @@ impl HostRequestWithStore for WasiHttp { let (result_tx, result_rx) = oneshot::channel(); let WasiHttpCtxView { table, .. } = store.get(); let headers = delete_fields(table, headers)?; + // `Content-Length` header value is validated in `fields` implementation let options = options .map(|options| delete_request_options(table, options)) .transpose()?; @@ -398,7 +326,7 @@ impl HostRequestWithStore for WasiHttp { options: options.map(Into::into), body, }; - let req = push_request(table, req)?; + let req = table.push(req).context("failed to push request to table")?; Ok(( req, FutureReader::new( @@ -424,59 +352,21 @@ impl HostRequestWithStore for WasiHttp { > { let getter = store.getter(); store.with(|mut store| { - let req = get_request_mut(store.get().table, &req)?; - match mem::replace(&mut req.body, Body::Consumed) { - Body::Guest { - contents_rx: Some(contents_rx), - trailers_rx, - result_tx, - } => { - // TODO: Use a result specified by the caller - // https://github.com/WebAssembly/wasi-http/issues/176 - _ = result_tx.send(Box::new(async { Ok(()) })); - Ok(Ok((contents_rx, trailers_rx))) - } - Body::Guest { - contents_rx: None, - trailers_rx, - result_tx, - } => { - let instance = store.instance(); - // TODO: Use a result specified by the caller - // https://github.com/WebAssembly/wasi-http/issues/176 - _ = result_tx.send(Box::new(async { Ok(()) })); - Ok(Ok(( - StreamReader::new(instance, &mut store, StreamEmptyProducer::default()), - trailers_rx, - ))) - } - Body::Host { body, result_tx } => { - let instance = store.instance(); - // TODO: Use a result specified by the caller - // https://github.com/WebAssembly/wasi-http/issues/176 - _ = result_tx.send(Box::new(async { Ok(()) })); - let (trailers_tx, trailers_rx) = oneshot::channel(); - Ok(Ok(( - StreamReader::new( - instance, - &mut store, - HostBodyStreamProducer { - body, - trailers: Some(trailers_tx), - getter, - }, - ), - FutureReader::new( - instance, - &mut store, - FutureOneshotProducer::from(trailers_rx), - ), - ))) - } - Body::Consumed => Ok(Err(())), - } + let Request { body, .. } = get_request_mut(store.get().table, &req)?; + let body = mem::replace(body, Body::Consumed); + Ok(body.consume(store, getter)) }) } + + fn drop(mut store: Access<'_, T, Self>, req: Resource) -> wasmtime::Result<()> { + let Request { body, .. } = store + .get() + .table + .delete(req) + .context("failed to delete request from table")?; + body.drop(store); + Ok(()) + } } impl HostRequest for WasiHttpCtxView<'_> { @@ -590,11 +480,6 @@ impl HostRequest for WasiHttpCtxView<'_> { let Request { headers, .. } = get_request(self.table, &req)?; push_fields(self.table, Fields::new_immutable(Arc::clone(headers))) } - - fn drop(&mut self, req: Resource) -> wasmtime::Result<()> { - delete_request(self.table, req)?; - Ok(()) - } } impl HostRequestOptions for WasiHttpCtxView<'_> { @@ -621,13 +506,11 @@ impl HostRequestOptions for WasiHttpCtxView<'_> { &mut self, opts: Resource, duration: Option, - ) -> wasmtime::Result> { + ) -> RequestOptionsResult<()> { let opts = get_request_options_mut(self.table, &opts)?; - let Some(opts) = opts.get_mut() else { - return Ok(Err(RequestOptionsError::Immutable)); - }; + let opts = opts.get_mut().ok_or(RequestOptionsError::Immutable)?; opts.connect_timeout = duration.map(core::time::Duration::from_nanos); - Ok(Ok(())) + Ok(()) } fn get_first_byte_timeout( @@ -649,13 +532,11 @@ impl HostRequestOptions for WasiHttpCtxView<'_> { &mut self, opts: Resource, duration: Option, - ) -> wasmtime::Result> { + ) -> RequestOptionsResult<()> { let opts = get_request_options_mut(self.table, &opts)?; - let Some(opts) = opts.get_mut() else { - return Ok(Err(RequestOptionsError::Immutable)); - }; + let opts = opts.get_mut().ok_or(RequestOptionsError::Immutable)?; opts.first_byte_timeout = duration.map(core::time::Duration::from_nanos); - Ok(Ok(())) + Ok(()) } fn get_between_bytes_timeout( @@ -677,13 +558,11 @@ impl HostRequestOptions for WasiHttpCtxView<'_> { &mut self, opts: Resource, duration: Option, - ) -> wasmtime::Result> { + ) -> RequestOptionsResult<()> { let opts = get_request_options_mut(self.table, &opts)?; - let Some(opts) = opts.get_mut() else { - return Ok(Err(RequestOptionsError::Immutable)); - }; + let opts = opts.get_mut().ok_or(RequestOptionsError::Immutable)?; opts.between_bytes_timeout = duration.map(core::time::Duration::from_nanos); - Ok(Ok(())) + Ok(()) } fn clone( @@ -722,7 +601,9 @@ impl HostResponseWithStore for WasiHttp { headers: headers.into(), body, }; - let res = push_response(table, res)?; + let res = table + .push(res) + .context("failed to push response to table")?; Ok(( res, FutureReader::new( @@ -748,59 +629,21 @@ impl HostResponseWithStore for WasiHttp { > { let getter = store.getter(); store.with(|mut store| { - let res = get_response_mut(store.get().table, &res)?; - match mem::replace(&mut res.body, Body::Consumed) { - Body::Guest { - contents_rx: Some(contents_rx), - trailers_rx, - result_tx, - } => { - // TODO: Use a result specified by the caller - // https://github.com/WebAssembly/wasi-http/issues/176 - _ = result_tx.send(Box::new(async { Ok(()) })); - Ok(Ok((contents_rx, trailers_rx))) - } - Body::Guest { - contents_rx: None, - trailers_rx, - result_tx, - } => { - let instance = store.instance(); - // TODO: Use a result specified by the caller - // https://github.com/WebAssembly/wasi-http/issues/176 - _ = result_tx.send(Box::new(async { Ok(()) })); - Ok(Ok(( - StreamReader::new(instance, &mut store, StreamEmptyProducer::default()), - trailers_rx, - ))) - } - Body::Host { body, result_tx } => { - let instance = store.instance(); - // TODO: Use a result specified by the caller - // https://github.com/WebAssembly/wasi-http/issues/176 - _ = result_tx.send(Box::new(async { Ok(()) })); - let (trailers_tx, trailers_rx) = oneshot::channel(); - Ok(Ok(( - StreamReader::new( - instance, - &mut store, - HostBodyStreamProducer { - body, - trailers: Some(trailers_tx), - getter, - }, - ), - FutureReader::new( - instance, - &mut store, - FutureOneshotProducer::from(trailers_rx), - ), - ))) - } - Body::Consumed => Ok(Err(())), - } + let Response { body, .. } = get_response_mut(store.get().table, &res)?; + let body = mem::replace(body, Body::Consumed); + Ok(body.consume(store, getter)) }) } + + fn drop(mut store: Access<'_, T, Self>, res: Resource) -> wasmtime::Result<()> { + let Response { body, .. } = store + .get() + .table + .delete(res) + .context("failed to delete response from table")?; + body.drop(store); + Ok(()) + } } impl HostResponse for WasiHttpCtxView<'_> { @@ -826,15 +669,24 @@ impl HostResponse for WasiHttpCtxView<'_> { let Response { headers, .. } = get_response(self.table, &res)?; push_fields(self.table, Fields::new_immutable(Arc::clone(headers))) } - - fn drop(&mut self, res: Resource) -> wasmtime::Result<()> { - delete_response(self.table, res)?; - Ok(()) - } } impl Host for WasiHttpCtxView<'_> { fn convert_error_code(&mut self, error: HttpError) -> wasmtime::Result { error.downcast() } + + fn convert_header_error( + &mut self, + error: crate::p3::HeaderError, + ) -> wasmtime::Result { + error.downcast() + } + + fn convert_request_options_error( + &mut self, + error: crate::p3::RequestOptionsError, + ) -> wasmtime::Result { + error.downcast() + } } diff --git a/crates/wasi-http/src/p3/mod.rs b/crates/wasi-http/src/p3/mod.rs index a0d4ce3f0471..d8a25e73b448 100644 --- a/crates/wasi-http/src/p3/mod.rs +++ b/crates/wasi-http/src/p3/mod.rs @@ -9,7 +9,7 @@ //! Documentation of this module may be incorrect or out-of-sync with the implementation. pub mod bindings; -pub mod body; +mod body; mod conv; mod host; mod proxy; @@ -27,14 +27,33 @@ use bindings::http::{handler, types}; use bytes::Bytes; use core::ops::Deref; use http::HeaderName; +use http::header::CONTENT_LENGTH; use http::uri::Scheme; use http_body_util::combinators::BoxBody; use std::sync::Arc; use wasmtime::component::{HasData, Linker, ResourceTable}; use wasmtime_wasi::TrappableError; -pub type HttpResult = Result; -pub type HttpError = TrappableError; +pub(crate) type HttpResult = Result; +pub(crate) type HttpError = TrappableError; + +pub(crate) type HeaderResult = Result; +pub(crate) type HeaderError = TrappableError; + +pub(crate) type RequestOptionsResult = Result; +pub(crate) type RequestOptionsError = TrappableError; + +/// Extract the `Content-Length` header value from a [`http::HeaderMap`], returning `None` if it's not +/// present. This function will return `Err` if it's not possible to parse the `Content-Length` +/// header. +fn get_content_length(headers: &http::HeaderMap) -> wasmtime::Result> { + let Some(v) = headers.get(CONTENT_LENGTH) else { + return Ok(None); + }; + let v = v.to_str()?; + let v = v.parse()?; + Ok(Some(v)) +} pub(crate) struct WasiHttp; @@ -42,6 +61,7 @@ impl HasData for WasiHttp { type Data<'a> = WasiHttpCtxView<'a>; } +/// A trait which provides internal WASI HTTP state. pub trait WasiHttpCtx: Send { /// Whether a given header should be considered forbidden and not allowed. fn is_forbidden_header(&mut self, name: &HeaderName) -> bool { @@ -112,6 +132,7 @@ pub trait WasiHttpCtx: Send { >; } +/// Default implementation of [WasiHttpCtx]. #[cfg(feature = "default-send-request")] #[derive(Clone, Default)] pub struct DefaultWasiHttpCtx; @@ -119,12 +140,18 @@ pub struct DefaultWasiHttpCtx; #[cfg(feature = "default-send-request")] impl WasiHttpCtx for DefaultWasiHttpCtx {} +/// View into [WasiHttpCtx] implementation and [ResourceTable]. pub struct WasiHttpCtxView<'a> { + /// Mutable reference to the WASI HTTP context. pub ctx: &'a mut dyn WasiHttpCtx, + + /// Mutable reference to table used to manage resources. pub table: &'a mut ResourceTable, } +/// A trait which provides internal WASI HTTP state. pub trait WasiHttpView: Send { + /// Return a [WasiHttpCtxView] from mutable reference to self. fn http(&mut self) -> WasiHttpCtxView<'_>; } @@ -186,8 +213,13 @@ where } /// An [Arc], which may be immutable. +/// +/// In `wasi:http` resources like `fields` or `request-options` may be +/// mutable or immutable. This construct is used to model them efficiently. pub enum MaybeMutable { + /// Clone-on-write, mutable [Arc] Mutable(Arc), + /// Immutable [Arc] Immutable(Arc), } @@ -201,15 +233,19 @@ impl Deref for MaybeMutable { type Target = Arc; fn deref(&self) -> &Self::Target { - self.as_arc() + match self { + Self::Mutable(v) | Self::Immutable(v) => v, + } } } impl MaybeMutable { + /// Construct a mutable [`MaybeMutable`]. pub fn new_mutable(v: impl Into>) -> Self { Self::Mutable(v.into()) } + /// Construct a mutable [`MaybeMutable`] filling it with default `T`. pub fn new_mutable_default() -> Self where T: Default, @@ -217,26 +253,23 @@ impl MaybeMutable { Self::new_mutable(T::default()) } + /// Construct an immutable [`MaybeMutable`]. pub fn new_immutable(v: impl Into>) -> Self { Self::Immutable(v.into()) } - fn as_arc(&self) -> &Arc { + /// Unwrap [`MaybeMutable`] into [`Arc`]. + pub fn into_arc(self) -> Arc { match self { Self::Mutable(v) | Self::Immutable(v) => v, } } - fn into_arc(self) -> Arc { - match self { - Self::Mutable(v) | Self::Immutable(v) => v, - } - } - - pub fn get(&self) -> &T { - self - } - + /// If this [`MaybeMutable`] is [`Mutable`](MaybeMutable::Mutable), + /// return a mutable reference to it, otherwise return `None`. + /// + /// Internally, this will use [`Arc::make_mut`] and will clone the underlying + /// value, if multiple strong references to the inner [`Arc`] exist. pub fn get_mut(&mut self) -> Option<&mut T> where T: Clone, diff --git a/crates/wasi-http/src/p3/request.rs b/crates/wasi-http/src/p3/request.rs index 413d5fae836c..90c041cce86e 100644 --- a/crates/wasi-http/src/p3/request.rs +++ b/crates/wasi-http/src/p3/request.rs @@ -9,6 +9,7 @@ use http_body_util::combinators::BoxBody; use std::sync::Arc; use tokio::sync::oneshot; +/// The concrete type behind a `wasi:http/types/request-options` resource. #[derive(Copy, Clone, Debug, Default)] pub struct RequestOptions { /// How long to wait for a connection to be established. @@ -39,6 +40,9 @@ pub struct Request { impl Request { /// Construct a new [Request] + /// + /// This returns a [Future] that the guest will use to communicate + /// a request processing error, if any. pub fn new( method: Method, scheme: Option, @@ -73,6 +77,9 @@ impl Request { } /// Construct a new [Request] from [http::Request]. + /// + /// This returns a [Future] that the guest will use to communicate + /// a request processing error, if any. pub fn from_http( req: http::Request, ) -> ( @@ -112,7 +119,7 @@ impl Request { /// The default implementation of how an outgoing request is sent. /// -/// This implementation is used by the `wasi:http/outgoing-handler` interface +/// This implementation is used by the `wasi:http/handler` interface /// default implementation. #[cfg(feature = "default-send-request")] pub async fn default_send_request( @@ -248,7 +255,44 @@ pub async fn default_send_request( .expect("comes from valid request"); let send = async move { - use crate::p3::body::IncomingResponseBody; + use core::task::Context; + + /// Wrapper around [hyper::body::Incoming] used to + /// account for request option timeout configuration + struct IncomingResponseBody { + incoming: hyper::body::Incoming, + timeout: tokio::time::Interval, + } + impl http_body::Body for IncomingResponseBody { + type Data = ::Data; + type Error = ErrorCode; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + match Pin::new(&mut self.as_mut().incoming).poll_frame(cx) { + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Err(err))) => { + Poll::Ready(Some(Err(ErrorCode::from_hyper_response_error(err)))) + } + Poll::Ready(Some(Ok(frame))) => { + self.timeout.reset(); + Poll::Ready(Some(Ok(frame))) + } + Poll::Pending => { + ready!(self.timeout.poll_tick(cx)); + Poll::Ready(Some(Err(ErrorCode::ConnectionReadTimeout))) + } + } + } + fn is_end_stream(&self) -> bool { + self.incoming.is_end_stream() + } + fn size_hint(&self) -> http_body::SizeHint { + self.incoming.size_hint() + } + } let res = tokio::time::timeout(first_byte_timeout, sender.send_request(req)) .await @@ -260,28 +304,31 @@ pub async fn default_send_request( }; let mut send = pin!(send); let mut conn = Some(conn); + // Wait for response while driving connection I/O let res = poll_fn(|cx| match send.as_mut().poll(cx) { Poll::Ready(Ok(res)) => Poll::Ready(Ok(res)), Poll::Ready(Err(err)) => Poll::Ready(Err(err)), Poll::Pending => { - if let Some(fut) = conn.as_mut() { - let res = ready!(Pin::new(fut).poll(cx)); - conn = None; - match res { - Ok(()) => match ready!(send.as_mut().poll(cx)) { - Ok(res) => Poll::Ready(Ok(res)), - Err(err) => Poll::Ready(Err(err)), - }, - Err(err) => Poll::Ready(Err(ErrorCode::from_hyper_request_error(err))), - } - } else { - Poll::Pending + // Response is not ready, poll `hyper` connection to drive I/O if it has not completed yet + let Some(fut) = conn.as_mut() else { + // `hyper` connection already completed + return Poll::Pending; + }; + let res = ready!(Pin::new(fut).poll(cx)); + // `hyper` connection completed, record that to prevent repeated poll + conn = None; + match res { + // `hyper` connection has successfully completed, optimistically poll for response + Ok(()) => send.as_mut().poll(cx), + // `hyper` connection has failed, return the error + Err(err) => Poll::Ready(Err(ErrorCode::from_hyper_request_error(err))), } } }) .await?; Ok((res, async move { let Some(conn) = conn.take() else { + // `hyper` connection has already completed return Ok(()); }; conn.await.map_err(ErrorCode::from_hyper_response_error) diff --git a/crates/wasi-http/src/p3/response.rs b/crates/wasi-http/src/p3/response.rs index d2549542a568..fc48d80ea2c7 100644 --- a/crates/wasi-http/src/p3/response.rs +++ b/crates/wasi-http/src/p3/response.rs @@ -1,6 +1,7 @@ -use crate::p3::WasiHttpView; use crate::p3::bindings::http::types::ErrorCode; -use crate::p3::body::{Body, ConsumedBody, GuestBody}; +use crate::p3::body::{Body, ConsumedBody, GuestBody, GuestBodyKind}; +use crate::p3::{WasiHttpView, get_content_length}; +use anyhow::Context as _; use bytes::Bytes; use http::{HeaderMap, StatusCode}; use http_body_util::BodyExt as _; @@ -36,48 +37,42 @@ impl TryFrom for http::Response { } impl Response { - /// Construct a new [Response] - pub fn new( - status: StatusCode, - headers: impl Into>, - body: impl Into>, - ) -> ( - Self, - impl Future> + Send + 'static, - ) { - let (tx, rx) = oneshot::channel(); - ( - Self { - status, - headers: headers.into(), - body: Body::Host { - body: body.into(), - result_tx: tx, - }, - }, - async { - let Ok(fut) = rx.await else { return Ok(()) }; - Box::into_pin(fut).await - }, - ) - } - /// Convert [Response] into [http::Response]. + /// + /// The specified [Future] `fut` can be used to communicate + /// a response processing error, if any, to the guest. pub fn into_http( self, store: impl AsContextMut, fut: impl Future> + Send + 'static, - ) -> http::Result>> { - let response = http::Response::try_from(self)?; - let (response, body) = response.into_parts(); + ) -> wasmtime::Result>> { + let res = http::Response::try_from(self)?; + let (res, body) = res.into_parts(); let body = match body { Body::Guest { contents_rx, trailers_rx, result_tx, } => { - _ = result_tx.send(Box::new(fut)); - GuestBody::new(store, contents_rx, trailers_rx, T::http).boxed() + let (http_result_tx, http_result_rx) = oneshot::channel(); + let content_length = + get_content_length(&res.headers).context("failed to parse `content-length`")?; + _ = result_tx.send(Box::new(async move { + if let Ok(Err(err)) = http_result_rx.await { + return Err(err); + }; + fut.await + })); + GuestBody::new( + store, + contents_rx, + trailers_rx, + http_result_tx, + content_length, + GuestBodyKind::Response, + T::http, + ) + .boxed() } Body::Host { body, result_tx } => { _ = result_tx.send(Box::new(fut)); @@ -85,6 +80,6 @@ impl Response { } Body::Consumed => ConsumedBody.boxed(), }; - Ok(http::Response::from_parts(response, body)) + Ok(http::Response::from_parts(res, body)) } } diff --git a/crates/wasi-http/tests/all/http_server.rs b/crates/wasi-http/tests/all/http_server.rs index 55a61bb98916..0f90abeb5788 100644 --- a/crates/wasi-http/tests/all/http_server.rs +++ b/crates/wasi-http/tests/all/http_server.rs @@ -1,29 +1,30 @@ use anyhow::{Context, Result}; -use http_body_util::{BodyExt, Full, combinators::BoxBody}; -use hyper::{Request, Response, body::Bytes, service::service_fn}; -use std::{ - future::Future, - net::{SocketAddr, TcpStream}, - thread::JoinHandle, -}; +use http::header::CONTENT_LENGTH; +use hyper::service::service_fn; +use hyper::{Request, Response}; +use std::future::Future; +use std::net::{SocketAddr, TcpStream}; +use std::thread::JoinHandle; use tokio::net::TcpListener; use tracing::{debug, trace, warn}; use wasmtime_wasi_http::io::TokioIo; async fn test( - mut req: Request, -) -> http::Result>> { - debug!("preparing mocked response",); + req: Request, +) -> http::Result> { + debug!(?req, "preparing mocked response for request"); let method = req.method().to_string(); - let body = req.body_mut().collect().await.unwrap(); - let buf = body.to_bytes(); - trace!("hyper request body size {:?}", buf.len()); - - Response::builder() - .status(http::StatusCode::OK) + let uri = req.uri().to_string(); + let resp = Response::builder() .header("x-wasmtime-test-method", method) - .header("x-wasmtime-test-uri", req.uri().to_string()) - .body(Full::::from(buf).boxed()) + .header("x-wasmtime-test-uri", uri); + let resp = if let Some(content_length) = req.headers().get(CONTENT_LENGTH) { + resp.header(CONTENT_LENGTH, content_length) + } else { + resp + }; + let body = req.into_body(); + resp.body(body) } pub struct Server { diff --git a/crates/wasi-http/tests/all/p3/mod.rs b/crates/wasi-http/tests/all/p3/mod.rs index 09cbe200bb3a..d87b6dadfd71 100644 --- a/crates/wasi-http/tests/all/p3/mod.rs +++ b/crates/wasi-http/tests/all/p3/mod.rs @@ -155,21 +155,18 @@ async fn p3_http_outbound_request_timeout() -> anyhow::Result<()> { run_cli(P3_HTTP_OUTBOUND_REQUEST_TIMEOUT_COMPONENT, &server).await } -#[ignore = "unimplemented"] // TODO: implement #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn p3_http_outbound_request_post() -> anyhow::Result<()> { let server = Server::http1(1)?; run_cli(P3_HTTP_OUTBOUND_REQUEST_POST_COMPONENT, &server).await } -#[ignore = "unimplemented"] // TODO: implement #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn p3_http_outbound_request_large_post() -> anyhow::Result<()> { let server = Server::http1(1)?; run_cli(P3_HTTP_OUTBOUND_REQUEST_LARGE_POST_COMPONENT, &server).await } -#[ignore = "unimplemented"] // TODO: implement #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn p3_http_outbound_request_put() -> anyhow::Result<()> { let server = Server::http1(1)?; @@ -216,14 +213,12 @@ async fn p3_http_outbound_request_invalid_dnsname() -> anyhow::Result<()> { run_cli(P3_HTTP_OUTBOUND_REQUEST_INVALID_DNSNAME_COMPONENT, &server).await } -#[ignore = "unimplemented"] // TODO: implement #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn p3_http_outbound_request_response_build() -> anyhow::Result<()> { let server = Server::http1(1)?; run_cli(P3_HTTP_OUTBOUND_REQUEST_RESPONSE_BUILD_COMPONENT, &server).await } -#[ignore = "unimplemented"] // FIXME(#11631) #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn p3_http_outbound_request_content_length() -> anyhow::Result<()> { let server = Server::http1(3)?;