diff --git a/src/client/legacy/connect/mod.rs b/src/client/legacy/connect/mod.rs index 3309dbb..90a9767 100644 --- a/src/client/legacy/connect/mod.rs +++ b/src/client/legacy/connect/mod.rs @@ -80,6 +80,8 @@ pub mod dns; #[cfg(feature = "tokio")] mod http; +pub mod proxy; + pub(crate) mod capture; pub use capture::{capture_connection, CaptureConnection}; diff --git a/src/client/legacy/connect/proxy/mod.rs b/src/client/legacy/connect/proxy/mod.rs new file mode 100644 index 0000000..b7a7c14 --- /dev/null +++ b/src/client/legacy/connect/proxy/mod.rs @@ -0,0 +1,5 @@ +//! Proxy helpers + +mod tunnel; + +pub use self::tunnel::Tunnel; diff --git a/src/client/legacy/connect/proxy/tunnel.rs b/src/client/legacy/connect/proxy/tunnel.rs new file mode 100644 index 0000000..4f8c515 --- /dev/null +++ b/src/client/legacy/connect/proxy/tunnel.rs @@ -0,0 +1,258 @@ +use std::error::Error as StdError; +use std::future::Future; +use std::marker::{PhantomData, Unpin}; +use std::pin::Pin; +use std::task::{self, Poll}; + +use http::{HeaderMap, HeaderValue, Uri}; +use hyper::rt::{Read, Write}; +use pin_project_lite::pin_project; +use tower_service::Service; + +/// Tunnel Proxy via HTTP CONNECT +/// +/// This is a connector that can be used by the `legacy::Client`. It wraps +/// another connector, and after getting an underlying connection, it creates +/// an HTTP CONNECT tunnel over it. +#[derive(Debug)] +pub struct Tunnel { + headers: Headers, + inner: C, + proxy_dst: Uri, +} + +#[derive(Clone, Debug)] +enum Headers { + Empty, + Auth(HeaderValue), + Extra(HeaderMap), +} + +#[derive(Debug)] +pub enum TunnelError { + ConnectFailed(Box), + Io(std::io::Error), + MissingHost, + ProxyAuthRequired, + ProxyHeadersTooLong, + TunnelUnexpectedEof, + TunnelUnsuccessful, +} + +pin_project! { + // Not publicly exported (so missing_docs doesn't trigger). + // + // We return this `Future` instead of the `Pin>` directly + // so that users don't rely on it fitting in a `Pin>` slot + // (and thus we can change the type in the future). + #[must_use = "futures do nothing unless polled"] + #[allow(missing_debug_implementations)] + pub struct Tunneling { + #[pin] + fut: BoxTunneling, + _marker: PhantomData, + } +} + +type BoxTunneling = Pin> + Send>>; + +impl Tunnel { + /// Create a new Tunnel service. + /// + /// This wraps an underlying connector, and stores the address of a + /// tunneling proxy server. + /// + /// A `Tunnel` can then be called with any destination. The `dst` passed to + /// `call` will not be used to create the underlying connection, but will + /// be used in an HTTP CONNECT request sent to the proxy destination. + pub fn new(proxy_dst: Uri, connector: C) -> Self { + Self { + headers: Headers::Empty, + inner: connector, + proxy_dst, + } + } + + /// Add `proxy-authorization` header value to the CONNECT request. + pub fn with_auth(mut self, mut auth: HeaderValue) -> Self { + // just in case the user forgot + auth.set_sensitive(true); + match self.headers { + Headers::Empty => { + self.headers = Headers::Auth(auth); + } + Headers::Auth(ref mut existing) => { + *existing = auth; + } + Headers::Extra(ref mut extra) => { + extra.insert(http::header::PROXY_AUTHORIZATION, auth); + } + } + + self + } + + /// Add extra headers to be sent with the CONNECT request. + /// + /// If existing headers have been set, these will be merged. + pub fn with_headers(mut self, mut headers: HeaderMap) -> Self { + match self.headers { + Headers::Empty => { + self.headers = Headers::Extra(headers); + } + Headers::Auth(auth) => { + headers + .entry(http::header::PROXY_AUTHORIZATION) + .or_insert(auth); + self.headers = Headers::Extra(headers); + } + Headers::Extra(ref mut extra) => { + extra.extend(headers); + } + } + + self + } +} + +impl Service for Tunnel +where + C: Service, + C::Future: Send + 'static, + C::Response: Read + Write + Unpin + Send + 'static, + C::Error: Into>, +{ + type Response = C::Response; + type Error = TunnelError; + type Future = Tunneling; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + futures_util::ready!(self.inner.poll_ready(cx)) + .map_err(|e| TunnelError::ConnectFailed(e.into()))?; + Poll::Ready(Ok(())) + } + + fn call(&mut self, dst: Uri) -> Self::Future { + let connecting = self.inner.call(self.proxy_dst.clone()); + let headers = self.headers.clone(); + + Tunneling { + fut: Box::pin(async move { + let conn = connecting + .await + .map_err(|e| TunnelError::ConnectFailed(e.into()))?; + tunnel( + conn, + dst.host().ok_or(TunnelError::MissingHost)?, + dst.port().map(|p| p.as_u16()).unwrap_or(443), + &headers, + ) + .await + }), + _marker: PhantomData, + } + } +} + +impl Future for Tunneling +where + F: Future>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + self.project().fut.poll(cx) + } +} + +async fn tunnel(mut conn: T, host: &str, port: u16, headers: &Headers) -> Result +where + T: Read + Write + Unpin, +{ + let mut buf = format!( + "\ + CONNECT {host}:{port} HTTP/1.1\r\n\ + Host: {host}:{port}\r\n\ + " + ) + .into_bytes(); + + match headers { + Headers::Auth(auth) => { + buf.extend_from_slice(b"Proxy-Authorization: "); + buf.extend_from_slice(auth.as_bytes()); + buf.extend_from_slice(b"\r\n"); + } + Headers::Extra(extra) => { + for (name, value) in extra { + buf.extend_from_slice(name.as_str().as_bytes()); + buf.extend_from_slice(b": "); + buf.extend_from_slice(value.as_bytes()); + buf.extend_from_slice(b"\r\n"); + } + } + Headers::Empty => (), + } + + // headers end + buf.extend_from_slice(b"\r\n"); + + crate::rt::write_all(&mut conn, &buf) + .await + .map_err(TunnelError::Io)?; + + let mut buf = [0; 8192]; + let mut pos = 0; + + loop { + let n = crate::rt::read(&mut conn, &mut buf[pos..]) + .await + .map_err(TunnelError::Io)?; + + if n == 0 { + return Err(TunnelError::TunnelUnexpectedEof); + } + pos += n; + + let recvd = &buf[..pos]; + if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") { + if recvd.ends_with(b"\r\n\r\n") { + return Ok(conn); + } + if pos == buf.len() { + return Err(TunnelError::ProxyHeadersTooLong); + } + // else read more + } else if recvd.starts_with(b"HTTP/1.1 407") { + return Err(TunnelError::ProxyAuthRequired); + } else { + return Err(TunnelError::TunnelUnsuccessful); + } + } +} + +impl std::fmt::Display for TunnelError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("tunnel error: ")?; + + f.write_str(match self { + TunnelError::MissingHost => "missing destination host", + TunnelError::ProxyAuthRequired => "proxy authorization required", + TunnelError::ProxyHeadersTooLong => "proxy response headers too long", + TunnelError::TunnelUnexpectedEof => "unexpected end of file", + TunnelError::TunnelUnsuccessful => "unsuccessful", + TunnelError::ConnectFailed(_) => "failed to create underlying connection", + TunnelError::Io(_) => "io error establishing tunnel", + }) + } +} + +impl std::error::Error for TunnelError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + TunnelError::Io(ref e) => Some(e), + TunnelError::ConnectFailed(ref e) => Some(&**e), + _ => None, + } + } +} diff --git a/src/rt/io.rs b/src/rt/io.rs new file mode 100644 index 0000000..0ce3ea9 --- /dev/null +++ b/src/rt/io.rs @@ -0,0 +1,33 @@ +use std::marker::Unpin; +use std::pin::Pin; +use std::task::Poll; + +use futures_util::future; +use futures_util::ready; +use hyper::rt::{Read, ReadBuf, Write}; + +pub(crate) async fn read(io: &mut T, buf: &mut [u8]) -> Result +where + T: Read + Unpin, +{ + future::poll_fn(move |cx| { + let mut buf = ReadBuf::new(buf); + ready!(Pin::new(&mut *io).poll_read(cx, buf.unfilled()))?; + Poll::Ready(Ok(buf.filled().len())) + }) + .await +} + +pub(crate) async fn write_all(io: &mut T, buf: &[u8]) -> Result<(), std::io::Error> +where + T: Write + Unpin, +{ + let mut n = 0; + future::poll_fn(move |cx| { + while n < buf.len() { + n += ready!(Pin::new(&mut *io).poll_write(cx, &buf[n..])?); + } + Poll::Ready(Ok(())) + }) + .await +} diff --git a/src/rt/mod.rs b/src/rt/mod.rs index 3ed8628..71363cc 100644 --- a/src/rt/mod.rs +++ b/src/rt/mod.rs @@ -1,5 +1,10 @@ //! Runtime utilities +#[cfg(feature = "client-legacy")] +mod io; +#[cfg(feature = "client-legacy")] +pub(crate) use self::io::{read, write_all}; + #[cfg(feature = "tokio")] pub mod tokio; diff --git a/tests/proxy.rs b/tests/proxy.rs new file mode 100644 index 0000000..f828bc1 --- /dev/null +++ b/tests/proxy.rs @@ -0,0 +1,37 @@ +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tower_service::Service; + +use hyper_util::client::legacy::connect::{proxy::Tunnel, HttpConnector}; + +#[cfg(not(miri))] +#[tokio::test] +async fn test_tunnel_works() { + let tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); + let addr = tcp.local_addr().expect("local_addr"); + + let proxy_dst = format!("http://{}", addr).parse().expect("uri"); + let mut connector = Tunnel::new(proxy_dst, HttpConnector::new()); + let t1 = tokio::spawn(async move { + let _conn = connector + .call("https://hyper.rs".parse().unwrap()) + .await + .expect("tunnel"); + }); + + let t2 = tokio::spawn(async move { + let (mut io, _) = tcp.accept().await.expect("accept"); + let mut buf = [0u8; 64]; + let n = io.read(&mut buf).await.expect("read 1"); + assert_eq!( + &buf[..n], + b"CONNECT hyper.rs:443 HTTP/1.1\r\nHost: hyper.rs:443\r\n\r\n" + ); + io.write_all(b"HTTP/1.1 200 OK\r\n\r\n") + .await + .expect("write 1"); + }); + + t1.await.expect("task 1"); + t2.await.expect("task 2"); +}