Skip to content

Commit a891979

Browse files
authored
feat(client): add proxy::Tunnel legacy util (#140)
1 parent c39da45 commit a891979

File tree

6 files changed

+340
-0
lines changed

6 files changed

+340
-0
lines changed

src/client/legacy/connect/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ pub mod dns;
8080
#[cfg(feature = "tokio")]
8181
mod http;
8282

83+
pub mod proxy;
84+
8385
pub(crate) mod capture;
8486
pub use capture::{capture_connection, CaptureConnection};
8587

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
//! Proxy helpers
2+
3+
mod tunnel;
4+
5+
pub use self::tunnel::Tunnel;
+258
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
use std::error::Error as StdError;
2+
use std::future::Future;
3+
use std::marker::{PhantomData, Unpin};
4+
use std::pin::Pin;
5+
use std::task::{self, Poll};
6+
7+
use http::{HeaderMap, HeaderValue, Uri};
8+
use hyper::rt::{Read, Write};
9+
use pin_project_lite::pin_project;
10+
use tower_service::Service;
11+
12+
/// Tunnel Proxy via HTTP CONNECT
13+
///
14+
/// This is a connector that can be used by the `legacy::Client`. It wraps
15+
/// another connector, and after getting an underlying connection, it creates
16+
/// an HTTP CONNECT tunnel over it.
17+
#[derive(Debug)]
18+
pub struct Tunnel<C> {
19+
headers: Headers,
20+
inner: C,
21+
proxy_dst: Uri,
22+
}
23+
24+
#[derive(Clone, Debug)]
25+
enum Headers {
26+
Empty,
27+
Auth(HeaderValue),
28+
Extra(HeaderMap),
29+
}
30+
31+
#[derive(Debug)]
32+
pub enum TunnelError {
33+
ConnectFailed(Box<dyn StdError + Send + Sync>),
34+
Io(std::io::Error),
35+
MissingHost,
36+
ProxyAuthRequired,
37+
ProxyHeadersTooLong,
38+
TunnelUnexpectedEof,
39+
TunnelUnsuccessful,
40+
}
41+
42+
pin_project! {
43+
// Not publicly exported (so missing_docs doesn't trigger).
44+
//
45+
// We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
46+
// so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
47+
// (and thus we can change the type in the future).
48+
#[must_use = "futures do nothing unless polled"]
49+
#[allow(missing_debug_implementations)]
50+
pub struct Tunneling<F, T> {
51+
#[pin]
52+
fut: BoxTunneling<T>,
53+
_marker: PhantomData<F>,
54+
}
55+
}
56+
57+
type BoxTunneling<T> = Pin<Box<dyn Future<Output = Result<T, TunnelError>> + Send>>;
58+
59+
impl<C> Tunnel<C> {
60+
/// Create a new Tunnel service.
61+
///
62+
/// This wraps an underlying connector, and stores the address of a
63+
/// tunneling proxy server.
64+
///
65+
/// A `Tunnel` can then be called with any destination. The `dst` passed to
66+
/// `call` will not be used to create the underlying connection, but will
67+
/// be used in an HTTP CONNECT request sent to the proxy destination.
68+
pub fn new(proxy_dst: Uri, connector: C) -> Self {
69+
Self {
70+
headers: Headers::Empty,
71+
inner: connector,
72+
proxy_dst,
73+
}
74+
}
75+
76+
/// Add `proxy-authorization` header value to the CONNECT request.
77+
pub fn with_auth(mut self, mut auth: HeaderValue) -> Self {
78+
// just in case the user forgot
79+
auth.set_sensitive(true);
80+
match self.headers {
81+
Headers::Empty => {
82+
self.headers = Headers::Auth(auth);
83+
}
84+
Headers::Auth(ref mut existing) => {
85+
*existing = auth;
86+
}
87+
Headers::Extra(ref mut extra) => {
88+
extra.insert(http::header::PROXY_AUTHORIZATION, auth);
89+
}
90+
}
91+
92+
self
93+
}
94+
95+
/// Add extra headers to be sent with the CONNECT request.
96+
///
97+
/// If existing headers have been set, these will be merged.
98+
pub fn with_headers(mut self, mut headers: HeaderMap) -> Self {
99+
match self.headers {
100+
Headers::Empty => {
101+
self.headers = Headers::Extra(headers);
102+
}
103+
Headers::Auth(auth) => {
104+
headers
105+
.entry(http::header::PROXY_AUTHORIZATION)
106+
.or_insert(auth);
107+
self.headers = Headers::Extra(headers);
108+
}
109+
Headers::Extra(ref mut extra) => {
110+
extra.extend(headers);
111+
}
112+
}
113+
114+
self
115+
}
116+
}
117+
118+
impl<C> Service<Uri> for Tunnel<C>
119+
where
120+
C: Service<Uri>,
121+
C::Future: Send + 'static,
122+
C::Response: Read + Write + Unpin + Send + 'static,
123+
C::Error: Into<Box<dyn StdError + Send + Sync>>,
124+
{
125+
type Response = C::Response;
126+
type Error = TunnelError;
127+
type Future = Tunneling<C::Future, C::Response>;
128+
129+
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
130+
futures_util::ready!(self.inner.poll_ready(cx))
131+
.map_err(|e| TunnelError::ConnectFailed(e.into()))?;
132+
Poll::Ready(Ok(()))
133+
}
134+
135+
fn call(&mut self, dst: Uri) -> Self::Future {
136+
let connecting = self.inner.call(self.proxy_dst.clone());
137+
let headers = self.headers.clone();
138+
139+
Tunneling {
140+
fut: Box::pin(async move {
141+
let conn = connecting
142+
.await
143+
.map_err(|e| TunnelError::ConnectFailed(e.into()))?;
144+
tunnel(
145+
conn,
146+
dst.host().ok_or(TunnelError::MissingHost)?,
147+
dst.port().map(|p| p.as_u16()).unwrap_or(443),
148+
&headers,
149+
)
150+
.await
151+
}),
152+
_marker: PhantomData,
153+
}
154+
}
155+
}
156+
157+
impl<F, T, E> Future for Tunneling<F, T>
158+
where
159+
F: Future<Output = Result<T, E>>,
160+
{
161+
type Output = Result<T, TunnelError>;
162+
163+
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
164+
self.project().fut.poll(cx)
165+
}
166+
}
167+
168+
async fn tunnel<T>(mut conn: T, host: &str, port: u16, headers: &Headers) -> Result<T, TunnelError>
169+
where
170+
T: Read + Write + Unpin,
171+
{
172+
let mut buf = format!(
173+
"\
174+
CONNECT {host}:{port} HTTP/1.1\r\n\
175+
Host: {host}:{port}\r\n\
176+
"
177+
)
178+
.into_bytes();
179+
180+
match headers {
181+
Headers::Auth(auth) => {
182+
buf.extend_from_slice(b"Proxy-Authorization: ");
183+
buf.extend_from_slice(auth.as_bytes());
184+
buf.extend_from_slice(b"\r\n");
185+
}
186+
Headers::Extra(extra) => {
187+
for (name, value) in extra {
188+
buf.extend_from_slice(name.as_str().as_bytes());
189+
buf.extend_from_slice(b": ");
190+
buf.extend_from_slice(value.as_bytes());
191+
buf.extend_from_slice(b"\r\n");
192+
}
193+
}
194+
Headers::Empty => (),
195+
}
196+
197+
// headers end
198+
buf.extend_from_slice(b"\r\n");
199+
200+
crate::rt::write_all(&mut conn, &buf)
201+
.await
202+
.map_err(TunnelError::Io)?;
203+
204+
let mut buf = [0; 8192];
205+
let mut pos = 0;
206+
207+
loop {
208+
let n = crate::rt::read(&mut conn, &mut buf[pos..])
209+
.await
210+
.map_err(TunnelError::Io)?;
211+
212+
if n == 0 {
213+
return Err(TunnelError::TunnelUnexpectedEof);
214+
}
215+
pos += n;
216+
217+
let recvd = &buf[..pos];
218+
if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
219+
if recvd.ends_with(b"\r\n\r\n") {
220+
return Ok(conn);
221+
}
222+
if pos == buf.len() {
223+
return Err(TunnelError::ProxyHeadersTooLong);
224+
}
225+
// else read more
226+
} else if recvd.starts_with(b"HTTP/1.1 407") {
227+
return Err(TunnelError::ProxyAuthRequired);
228+
} else {
229+
return Err(TunnelError::TunnelUnsuccessful);
230+
}
231+
}
232+
}
233+
234+
impl std::fmt::Display for TunnelError {
235+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236+
f.write_str("tunnel error: ")?;
237+
238+
f.write_str(match self {
239+
TunnelError::MissingHost => "missing destination host",
240+
TunnelError::ProxyAuthRequired => "proxy authorization required",
241+
TunnelError::ProxyHeadersTooLong => "proxy response headers too long",
242+
TunnelError::TunnelUnexpectedEof => "unexpected end of file",
243+
TunnelError::TunnelUnsuccessful => "unsuccessful",
244+
TunnelError::ConnectFailed(_) => "failed to create underlying connection",
245+
TunnelError::Io(_) => "io error establishing tunnel",
246+
})
247+
}
248+
}
249+
250+
impl std::error::Error for TunnelError {
251+
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
252+
match self {
253+
TunnelError::Io(ref e) => Some(e),
254+
TunnelError::ConnectFailed(ref e) => Some(&**e),
255+
_ => None,
256+
}
257+
}
258+
}

src/rt/io.rs

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
use std::marker::Unpin;
2+
use std::pin::Pin;
3+
use std::task::Poll;
4+
5+
use futures_util::future;
6+
use futures_util::ready;
7+
use hyper::rt::{Read, ReadBuf, Write};
8+
9+
pub(crate) async fn read<T>(io: &mut T, buf: &mut [u8]) -> Result<usize, std::io::Error>
10+
where
11+
T: Read + Unpin,
12+
{
13+
future::poll_fn(move |cx| {
14+
let mut buf = ReadBuf::new(buf);
15+
ready!(Pin::new(&mut *io).poll_read(cx, buf.unfilled()))?;
16+
Poll::Ready(Ok(buf.filled().len()))
17+
})
18+
.await
19+
}
20+
21+
pub(crate) async fn write_all<T>(io: &mut T, buf: &[u8]) -> Result<(), std::io::Error>
22+
where
23+
T: Write + Unpin,
24+
{
25+
let mut n = 0;
26+
future::poll_fn(move |cx| {
27+
while n < buf.len() {
28+
n += ready!(Pin::new(&mut *io).poll_write(cx, &buf[n..])?);
29+
}
30+
Poll::Ready(Ok(()))
31+
})
32+
.await
33+
}

src/rt/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
//! Runtime utilities
22
3+
#[cfg(feature = "client-legacy")]
4+
mod io;
5+
#[cfg(feature = "client-legacy")]
6+
pub(crate) use self::io::{read, write_all};
7+
38
#[cfg(feature = "tokio")]
49
pub mod tokio;
510

tests/proxy.rs

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
2+
use tokio::net::TcpListener;
3+
use tower_service::Service;
4+
5+
use hyper_util::client::legacy::connect::{proxy::Tunnel, HttpConnector};
6+
7+
#[cfg(not(miri))]
8+
#[tokio::test]
9+
async fn test_tunnel_works() {
10+
let tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind");
11+
let addr = tcp.local_addr().expect("local_addr");
12+
13+
let proxy_dst = format!("http://{}", addr).parse().expect("uri");
14+
let mut connector = Tunnel::new(proxy_dst, HttpConnector::new());
15+
let t1 = tokio::spawn(async move {
16+
let _conn = connector
17+
.call("https://hyper.rs".parse().unwrap())
18+
.await
19+
.expect("tunnel");
20+
});
21+
22+
let t2 = tokio::spawn(async move {
23+
let (mut io, _) = tcp.accept().await.expect("accept");
24+
let mut buf = [0u8; 64];
25+
let n = io.read(&mut buf).await.expect("read 1");
26+
assert_eq!(
27+
&buf[..n],
28+
b"CONNECT hyper.rs:443 HTTP/1.1\r\nHost: hyper.rs:443\r\n\r\n"
29+
);
30+
io.write_all(b"HTTP/1.1 200 OK\r\n\r\n")
31+
.await
32+
.expect("write 1");
33+
});
34+
35+
t1.await.expect("task 1");
36+
t2.await.expect("task 2");
37+
}

0 commit comments

Comments
 (0)