Skip to content

Commit 0df5320

Browse files
committed
feat: add support for processing handshake packets async via compute-heavy-future-executor
1 parent cd399ab commit 0df5320

File tree

11 files changed

+507
-21
lines changed

11 files changed

+507
-21
lines changed

.github/workflows/CI.yml

+2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ jobs:
6262
run: |
6363
cargo test --locked --all
6464
cargo test --locked -p tokio-rustls --features early-data --test early-data
65+
# we run all test suites against this feature since it shifts the default behavior globally
66+
cargo test --locked -p tokio-rustls --features compute-heavy-future-executor
6567
6668
lints:
6769
name: Lints

Cargo.lock

+23-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+8
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@ rust-version = "1.70"
1313
exclude = ["/.github", "/examples", "/scripts"]
1414

1515
[dependencies]
16+
# implicitly enables the tokio feature for compute-heavy-future-executor
17+
# (defaulting to strategy of spawn_blocking w/ concurrency conctorl)
18+
compute-heavy-future-executor = { version = "0.1", optional = true}
19+
pin-project-lite = { version = "0.2.15", optional = true }
1620
rustls = { version = "0.23.15", default-features = false, features = ["std"] }
1721
tokio = "1.0"
1822

1923
[features]
2024
default = ["logging", "tls12", "aws_lc_rs"]
2125
aws_lc_rs = ["rustls/aws_lc_rs"]
2226
aws-lc-rs = ["aws_lc_rs"] # Alias because Cargo features commonly use `-`
27+
compute-heavy-future-executor = ["dep:compute-heavy-future-executor", "pin-project-lite"]
2328
early-data = []
2429
fips = ["rustls/fips"]
2530
logging = ["rustls/logging"]
@@ -33,3 +38,6 @@ lazy_static = "1.1"
3338
rcgen = { version = "0.13", features = ["pem"] }
3439
tokio = { version = "1.0", features = ["full"] }
3540
webpki-roots = "0.26"
41+
42+
[patch.crates-io]
43+
compute-heavy-future-executor = { path = "../compute-heavy-future-executor" }

src/client.rs

+35-1
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,15 @@ where
6161
self.get_ref().0.as_raw_socket()
6262
}
6363
}
64+
#[cfg(feature = "early-data")]
65+
type TlsStreamExtras = Option<Waker>;
66+
#[cfg(not(feature = "early-data"))]
67+
type TlsStreamExtras = ();
6468

6569
impl<IO> IoSession for TlsStream<IO> {
6670
type Io = IO;
6771
type Session = ClientConnection;
72+
type Extras = TlsStreamExtras;
6873

6974
#[inline]
7075
fn skip_handshake(&self) -> bool {
@@ -80,6 +85,35 @@ impl<IO> IoSession for TlsStream<IO> {
8085
fn into_io(self) -> Self::Io {
8186
self.io
8287
}
88+
89+
#[inline]
90+
fn into_inner(self) -> (TlsState, Self::Io, Self::Session, Self::Extras) {
91+
#[cfg(feature = "early-data")]
92+
return (self.state, self.io, self.session, self.early_waker);
93+
94+
#[cfg(not(feature = "early-data"))]
95+
(self.state, self.io, self.session, ())
96+
}
97+
98+
#[inline]
99+
#[allow(unused_variables)]
100+
fn from_inner(
101+
state: TlsState,
102+
io: Self::Io,
103+
session: Self::Session,
104+
extras: Self::Extras,
105+
) -> Self {
106+
#[cfg(feature = "early-data")]
107+
return Self {
108+
io,
109+
session,
110+
state,
111+
early_waker: extras,
112+
};
113+
114+
#[cfg(not(feature = "early-data"))]
115+
Self { io, session, state }
116+
}
83117
}
84118

85119
impl<IO> AsyncRead for TlsStream<IO>
@@ -287,7 +321,7 @@ where
287321

288322
// complete handshake
289323
while stream.session.is_handshaking() {
290-
ready!(stream.handshake(cx))?;
324+
ready!(stream.handshake(cx, false))?;
291325
}
292326

293327
// write early data (fallback)

src/common/async_session.rs

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
use std::{
2+
future::Future,
3+
io,
4+
ops::{Deref, DerefMut},
5+
pin::Pin,
6+
task::{Context, Poll},
7+
};
8+
9+
use pin_project_lite::pin_project;
10+
use rustls::{ConnectionCommon, SideData};
11+
use tokio::io::{AsyncRead, AsyncWrite};
12+
13+
use crate::common::IoSession;
14+
15+
use super::{Stream, TlsState};
16+
17+
/// Full result of sync closure
18+
type SessionResult<S> = Result<S, (Option<S>, io::Error)>;
19+
/// Executor result wrapping sync closure result
20+
type SyncExecutorResult<S> = Result<SessionResult<S>, compute_heavy_future_executor::Error>;
21+
/// Future wrapping waiting on executor
22+
type SessionFuture<S> = Box<dyn Future<Output = SyncExecutorResult<S>> + Unpin + Send>;
23+
24+
pin_project! {
25+
/// Session is off doing compute-heavy sync work, such as initializing the session or processing handshake packets.
26+
/// Might be on another thread / external threadpool.
27+
///
28+
/// This future sleeps on it in current worker thread until it completes.
29+
pub(crate) struct AsyncSession<IS: IoSession> {
30+
#[pin]
31+
future: SessionFuture<IS::Session>,
32+
io: IS::Io,
33+
state: TlsState,
34+
extras: IS::Extras,
35+
}
36+
}
37+
38+
impl<IS, SD> AsyncSession<IS>
39+
where
40+
IS: IoSession + Unpin,
41+
IS::Io: AsyncRead + AsyncWrite + Unpin,
42+
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin + Send + 'static,
43+
SD: SideData,
44+
{
45+
pub(crate) fn process_packets(stream: IS) -> Self {
46+
let (state, io, mut session, extras) = stream.into_inner();
47+
48+
let closure = move || match session.process_new_packets() {
49+
Ok(_) => Ok(session),
50+
Err(err) => Err((
51+
Some(session),
52+
io::Error::new(io::ErrorKind::InvalidData, err),
53+
)),
54+
};
55+
56+
let future = compute_heavy_future_executor::execute_sync(closure);
57+
58+
Self {
59+
future: Box::new(Box::pin(future)),
60+
io,
61+
state,
62+
extras,
63+
}
64+
}
65+
66+
pub(crate) fn into_stream(
67+
mut self,
68+
session_result: Result<IS::Session, (Option<IS::Session>, io::Error)>,
69+
cx: &mut Context<'_>,
70+
) -> Result<IS, (io::Error, IS::Io)> {
71+
match session_result {
72+
Ok(session) => Ok(IS::from_inner(self.state, self.io, session, self.extras)),
73+
Err((Some(mut session), err)) => {
74+
// In case we have an alert to send describing this error,
75+
// try a last-gasp write -- but don't predate the primary
76+
// error.
77+
let mut tls_stream: Stream<'_, <IS as IoSession>::Io, <IS as IoSession>::Session> =
78+
Stream::new(&mut self.io, &mut session).set_eof(!self.state.readable());
79+
let _ = tls_stream.write_io(cx);
80+
81+
// still drop the tls session and return the io error only
82+
Err((err, self.io))
83+
}
84+
Err((None, err)) => Err((err, self.io)),
85+
}
86+
}
87+
88+
#[inline]
89+
pub fn get_ref(&self) -> &IS::Io {
90+
&self.io
91+
}
92+
93+
#[inline]
94+
pub fn get_mut(&mut self) -> &mut IS::Io {
95+
&mut self.io
96+
}
97+
}
98+
99+
impl<IS, SD> Future for AsyncSession<IS>
100+
where
101+
IS: IoSession + Unpin,
102+
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin + Send + 'static,
103+
SD: SideData,
104+
{
105+
type Output = Result<IS::Session, (Option<IS::Session>, io::Error)>;
106+
107+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
108+
let mut this = self.project();
109+
110+
match ready!(this.future.as_mut().poll(cx)) {
111+
Ok(session_res) => match session_res {
112+
Ok(res) => Poll::Ready(Ok(res)),
113+
// return any session along with the error,
114+
// so the caller can flush any remaining alerts in buffer to i/o
115+
Err((session, err)) => Poll::Ready(Err((
116+
session,
117+
io::Error::new(io::ErrorKind::InvalidData, err),
118+
))),
119+
},
120+
// We don't have a session to flush here because the executor ate it
121+
// TODO: not all errors should be modeled as io
122+
Err(executor_error) => Poll::Ready(Err((
123+
None,
124+
io::Error::new(io::ErrorKind::Other, executor_error),
125+
))),
126+
}
127+
}
128+
}

src/common/handshake.rs

+56-3
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,32 @@ use tokio::io::{AsyncRead, AsyncWrite};
1010

1111
use crate::common::{Stream, SyncWriteAdapter, TlsState};
1212

13+
#[cfg(feature = "compute-heavy-future-executor")]
14+
use super::async_session::AsyncSession;
15+
1316
pub(crate) trait IoSession {
1417
type Io;
1518
type Session;
19+
type Extras;
1620

1721
fn skip_handshake(&self) -> bool;
1822
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session);
1923
fn into_io(self) -> Self::Io;
24+
#[allow(dead_code)]
25+
fn into_inner(self) -> (TlsState, Self::Io, Self::Session, Self::Extras);
26+
#[allow(dead_code)]
27+
fn from_inner(
28+
state: TlsState,
29+
io: Self::Io,
30+
session: Self::Session,
31+
extras: Self::Extras,
32+
) -> Self;
2033
}
2134

2235
pub(crate) enum MidHandshake<IS: IoSession> {
2336
Handshaking(IS),
37+
#[cfg(feature = "compute-heavy-future-executor")]
38+
AsyncSession(AsyncSession<IS>),
2439
End,
2540
SendAlert {
2641
io: IS::Io,
@@ -32,12 +47,11 @@ pub(crate) enum MidHandshake<IS: IoSession> {
3247
error: io::Error,
3348
},
3449
}
35-
3650
impl<IS, SD> Future for MidHandshake<IS>
3751
where
3852
IS: IoSession + Unpin,
3953
IS::Io: AsyncRead + AsyncWrite + Unpin,
40-
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin,
54+
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin + Send + 'static,
4155
SD: SideData,
4256
{
4357
type Output = Result<IS, (io::Error, IS::Io)>;
@@ -47,6 +61,12 @@ where
4761

4862
let mut stream = match mem::replace(this, MidHandshake::End) {
4963
MidHandshake::Handshaking(stream) => stream,
64+
#[cfg(feature = "compute-heavy-future-executor")]
65+
MidHandshake::AsyncSession(mut async_session) => {
66+
let pinned = Pin::new(&mut async_session);
67+
let session_result = ready!(pinned.poll(cx));
68+
async_session.into_stream(session_result, cx)?
69+
}
5070
MidHandshake::SendAlert {
5171
mut io,
5272
mut alert,
@@ -74,6 +94,35 @@ where
7494
( $e:expr ) => {
7595
match $e {
7696
Poll::Ready(Ok(_)) => (),
97+
#[cfg(feature = "compute-heavy-future-executor")]
98+
Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::WouldBlock => {
99+
// TODO: downcast to decide on closure, for now we only do this for
100+
// process_packets
101+
102+
// decompose the stream and send the session to background executor
103+
let mut async_session = AsyncSession::process_packets(stream);
104+
105+
let pinned = Pin::new(&mut async_session);
106+
// poll once to kick off work
107+
match pinned.poll(cx) {
108+
// didn't need to sleep for async session
109+
Poll::Ready(res) => {
110+
let stream = async_session.into_stream(res, cx)?;
111+
// rather than continuing processing here,
112+
// we keep memory management simple and recompose
113+
// our future for a fresh poll
114+
*this = MidHandshake::Handshaking(stream);
115+
// tell executor to immediately poll us again
116+
cx.waker().wake_by_ref();
117+
return Poll::Pending;
118+
}
119+
// task is sleeping until async session is complete
120+
Poll::Pending => {
121+
*this = MidHandshake::AsyncSession(async_session);
122+
return Poll::Pending;
123+
}
124+
}
125+
}
77126
Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))),
78127
Poll::Pending => {
79128
*this = MidHandshake::Handshaking(stream);
@@ -83,8 +132,12 @@ where
83132
};
84133
}
85134

135+
86136
while tls_stream.session.is_handshaking() {
87-
try_poll!(tls_stream.handshake(cx));
137+
#[cfg(feature = "compute-heavy-future-executor")]
138+
try_poll!(tls_stream.handshake(cx, true));
139+
#[cfg(not(feature = "compute-heavy-future-executor"))]
140+
try_poll!(tls_stream.handshake(cx, false));
88141
}
89142

90143
try_poll!(Pin::new(&mut tls_stream).poll_flush(cx));

0 commit comments

Comments
 (0)