Skip to content

Commit 2ba3db5

Browse files
committed
Introduce async callbacks
We introduce tokio_boring::SslContextBuilderExt, with 2 methods: * set_async_select_certificate_callback * set_async_private_key_method
1 parent 887f6fd commit 2ba3db5

8 files changed

+578
-3
lines changed

boring/src/ssl/mod.rs

+13
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,9 @@ pub struct SelectCertError(ffi::ssl_select_cert_result_t);
482482
impl SelectCertError {
483483
/// A fatal error occured and the handshake should be terminated.
484484
pub const ERROR: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_error);
485+
486+
/// The operation could not be completed and should be retried later.
487+
pub const RETRY: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_retry);
485488
}
486489

487490
/// Extension types, to be used with `ClientHello::get_extension`.
@@ -3260,6 +3263,11 @@ impl<S> MidHandshakeSslStream<S> {
32603263
self.stream.ssl()
32613264
}
32623265

3266+
/// Returns a mutable reference to the `Ssl` of the stream.
3267+
pub fn ssl_mut(&mut self) -> &mut SslRef {
3268+
self.stream.ssl_mut()
3269+
}
3270+
32633271
/// Returns the underlying error which interrupted this handshake.
32643272
pub fn error(&self) -> &Error {
32653273
&self.error
@@ -3514,6 +3522,11 @@ impl<S> SslStream<S> {
35143522
pub fn ssl(&self) -> &SslRef {
35153523
&self.ssl
35163524
}
3525+
3526+
/// Returns a mutable reference to the `Ssl` object associated with this stream.
3527+
pub fn ssl_mut(&mut self) -> &mut SslRef {
3528+
&mut self.ssl
3529+
}
35173530
}
35183531

35193532
impl<S: Read + Write> Read for SslStream<S> {

boring/src/ssl/test/private_key_method.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ fn test_sign_retry_complete_failure() {
189189
ErrorCode::WANT_PRIVATE_KEY_OPERATION
190190
);
191191

192-
let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err() else {
192+
let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err()
193+
else {
193194
panic!("should be WouldBlock");
194195
};
195196

tokio-boring/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ no-patches = ["boring/no-patches"]
3939
[dependencies]
4040
boring = { workspace = true }
4141
boring-sys = { workspace = true }
42+
once_cell = { workspace = true }
4243
tokio = { workspace = true }
4344

4445
[dev-dependencies]

tokio-boring/src/async_callbacks.rs

+262
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
use boring::ex_data::Index;
2+
use boring::ssl::{self, ClientHello, PrivateKeyMethod, Ssl, SslContextBuilder};
3+
use once_cell::sync::Lazy;
4+
use std::future::Future;
5+
use std::pin::Pin;
6+
use std::task::{ready, Context, Poll, Waker};
7+
8+
type BoxSelectCertFuture = ExDataFuture<Result<BoxSelectCertFinish, AsyncSelectCertError>>;
9+
10+
type BoxSelectCertFinish = Box<dyn FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError>>;
11+
12+
/// The type of futures returned by [`AsyncPrivateKeyMethod`] methods.
13+
pub type BoxPrivateKeyMethodFuture =
14+
ExDataFuture<Result<BoxPrivateKeyMethodFinish, AsyncPrivateKeyMethodError>>;
15+
16+
/// The type of callbacks returned by [`BoxPrivateKeyMethodFuture`].
17+
pub type BoxPrivateKeyMethodFinish =
18+
Box<dyn FnOnce(&mut ssl::SslRef, &mut [u8]) -> Result<usize, AsyncPrivateKeyMethodError>>;
19+
20+
type ExDataFuture<T> = Pin<Box<dyn Future<Output = T> + Send + Sync>>;
21+
22+
pub(crate) static TASK_WAKER_INDEX: Lazy<Index<Ssl, Option<Waker>>> =
23+
Lazy::new(|| Ssl::new_ex_index().unwrap());
24+
pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy<Index<Ssl, BoxSelectCertFuture>> =
25+
Lazy::new(|| Ssl::new_ex_index().unwrap());
26+
pub(crate) static SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX: Lazy<
27+
Index<Ssl, BoxPrivateKeyMethodFuture>,
28+
> = Lazy::new(|| Ssl::new_ex_index().unwrap());
29+
30+
/// Extensions to [`SslContextBuilder`].
31+
///
32+
/// This trait provides additional methods to use async callbacks with boring.
33+
pub trait SslContextBuilderExt: private::Sealed {
34+
/// Sets a callback that is called before most [`ClientHello`] processing
35+
/// and before the decision whether to resume a session is made. The
36+
/// callback may inspect the [`ClientHello`] and configure the connection.
37+
///
38+
/// This method uses a function that returns a future whose output is
39+
/// itself a closure that will be passed [`ClientHello`] to configure
40+
/// the connection based on the computations done in the future.
41+
///
42+
/// See [`SslContextBuilder::set_select_certificate_callback`] for the sync
43+
/// setter of this callback.
44+
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
45+
where
46+
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
47+
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
48+
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static;
49+
50+
/// Configures a custom private key method on the context.
51+
///
52+
/// See [`AsyncPrivateKeyMethod`] for more details.
53+
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod);
54+
}
55+
56+
impl SslContextBuilderExt for SslContextBuilder {
57+
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
58+
where
59+
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
60+
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
61+
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static,
62+
{
63+
self.set_select_certificate_callback(move |mut client_hello| {
64+
let fut_poll_result = with_ex_data_future(
65+
&mut client_hello,
66+
*SELECT_CERT_FUTURE_INDEX,
67+
ClientHello::ssl_mut,
68+
|client_hello| {
69+
let fut = callback(client_hello)?;
70+
71+
Ok(Box::pin(async move {
72+
Ok(Box::new(fut.await?) as BoxSelectCertFinish)
73+
}))
74+
},
75+
);
76+
77+
let fut_result = match fut_poll_result {
78+
Poll::Ready(fut_result) => fut_result,
79+
Poll::Pending => return Err(ssl::SelectCertError::RETRY),
80+
};
81+
82+
let finish = fut_result.or(Err(ssl::SelectCertError::ERROR))?;
83+
84+
finish(client_hello).or(Err(ssl::SelectCertError::ERROR))
85+
})
86+
}
87+
88+
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) {
89+
self.set_private_key_method(AsyncPrivateKeyMethodBridge(Box::new(method)));
90+
}
91+
}
92+
93+
/// A fatal error to be returned from async select certificate callbacks.
94+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
95+
pub struct AsyncSelectCertError;
96+
97+
/// Describes async private key hooks. This is used to off-load signing
98+
/// operations to a custom, potentially asynchronous, backend. Metadata about the
99+
/// key such as the type and size are parsed out of the certificate.
100+
///
101+
/// See [`PrivateKeyMethod`] for the sync version of those hooks.
102+
///
103+
/// [`ssl_private_key_method_st`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_private_key_method_st
104+
pub trait AsyncPrivateKeyMethod: Send + Sync + 'static {
105+
/// Signs the message `input` using the specified signature algorithm.
106+
///
107+
/// This method uses a function that returns a future whose output is
108+
/// itself a closure that will be passed `ssl` and `output`
109+
/// to finish writing the signature.
110+
///
111+
/// See [`PrivateKeyMethod::sign`] for the sync version of this method.
112+
fn sign(
113+
&self,
114+
ssl: &mut ssl::SslRef,
115+
input: &[u8],
116+
signature_algorithm: ssl::SslSignatureAlgorithm,
117+
output: &mut [u8],
118+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>;
119+
120+
/// Decrypts `input`.
121+
///
122+
/// This method uses a function that returns a future whose output is
123+
/// itself a closure that will be passed `ssl` and `output`
124+
/// to finish decrypting the input.
125+
///
126+
/// See [`PrivateKeyMethod::decrypt`] for the sync version of this method.
127+
fn decrypt(
128+
&self,
129+
ssl: &mut ssl::SslRef,
130+
input: &[u8],
131+
output: &mut [u8],
132+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>;
133+
}
134+
135+
/// A fatal error to be returned from async private key methods.
136+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
137+
pub struct AsyncPrivateKeyMethodError;
138+
139+
struct AsyncPrivateKeyMethodBridge(Box<dyn AsyncPrivateKeyMethod>);
140+
141+
impl PrivateKeyMethod for AsyncPrivateKeyMethodBridge {
142+
fn sign(
143+
&self,
144+
ssl: &mut ssl::SslRef,
145+
input: &[u8],
146+
signature_algorithm: ssl::SslSignatureAlgorithm,
147+
output: &mut [u8],
148+
) -> Result<usize, ssl::PrivateKeyMethodError> {
149+
with_private_key_method(ssl, output, |ssl, output| {
150+
<dyn AsyncPrivateKeyMethod>::sign(&*self.0, ssl, input, signature_algorithm, output)
151+
})
152+
}
153+
154+
fn decrypt(
155+
&self,
156+
ssl: &mut ssl::SslRef,
157+
input: &[u8],
158+
output: &mut [u8],
159+
) -> Result<usize, ssl::PrivateKeyMethodError> {
160+
with_private_key_method(ssl, output, |ssl, output| {
161+
<dyn AsyncPrivateKeyMethod>::decrypt(&*self.0, ssl, input, output)
162+
})
163+
}
164+
165+
fn complete(
166+
&self,
167+
ssl: &mut ssl::SslRef,
168+
output: &mut [u8],
169+
) -> Result<usize, ssl::PrivateKeyMethodError> {
170+
with_private_key_method(ssl, output, |_, _| {
171+
// This should never be reached, if it does, that's a bug on boring's side,
172+
// which called `complete` without having been returned to with a pending
173+
// future from `sign` or `decrypt`.
174+
175+
if cfg!(debug_assertions) {
176+
panic!("BUG: boring called complete without a pending operation");
177+
}
178+
179+
Err(AsyncPrivateKeyMethodError)
180+
})
181+
}
182+
}
183+
184+
/// Creates and drives a private key method future.
185+
///
186+
/// This is a convenience function for the three methods of impl `PrivateKeyMethod``
187+
/// for `dyn AsyncPrivateKeyMethod`. It relies on [`with_ex_data_future`] to
188+
/// drive the future and then immediately calls the final [`BoxPrivateKeyMethodFinish`]
189+
/// when the future is ready.
190+
fn with_private_key_method(
191+
ssl: &mut ssl::SslRef,
192+
output: &mut [u8],
193+
create_fut: impl FnOnce(
194+
&mut ssl::SslRef,
195+
&mut [u8],
196+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>,
197+
) -> Result<usize, ssl::PrivateKeyMethodError> {
198+
let fut_poll_result = with_ex_data_future(
199+
ssl,
200+
*SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX,
201+
|ssl| ssl,
202+
|ssl| create_fut(ssl, output),
203+
);
204+
205+
let fut_result = match fut_poll_result {
206+
Poll::Ready(fut_result) => fut_result,
207+
Poll::Pending => return Err(ssl::PrivateKeyMethodError::RETRY),
208+
};
209+
210+
let finish = fut_result.or(Err(ssl::PrivateKeyMethodError::FAILURE))?;
211+
212+
finish(ssl, output).or(Err(ssl::PrivateKeyMethodError::FAILURE))
213+
}
214+
215+
/// Creates and drives a future stored in `ssl_handle`'s `Ssl` at ex data index `index`.
216+
///
217+
/// This function won't even bother storing the future in `index` if the future
218+
/// created by `create_fut` returns `Poll::Ready(_)` on the first poll call.
219+
fn with_ex_data_future<H, T, E>(
220+
ssl_handle: &mut H,
221+
index: Index<ssl::Ssl, ExDataFuture<Result<T, E>>>,
222+
get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef,
223+
create_fut: impl FnOnce(&mut H) -> Result<ExDataFuture<Result<T, E>>, E>,
224+
) -> Poll<Result<T, E>> {
225+
let ssl = get_ssl_mut(ssl_handle);
226+
let waker = ssl
227+
.ex_data(*TASK_WAKER_INDEX)
228+
.cloned()
229+
.flatten()
230+
.expect("task waker should be set");
231+
232+
let mut ctx = Context::from_waker(&waker);
233+
234+
match ssl.ex_data_mut(index) {
235+
Some(fut) => {
236+
let fut_result = ready!(fut.as_mut().poll(&mut ctx));
237+
238+
// NOTE(nox): For memory usage concerns, maybe we should implement
239+
// a way to remove the stored future from the `Ssl` value here?
240+
241+
Poll::Ready(fut_result)
242+
}
243+
None => {
244+
let mut fut = create_fut(ssl_handle)?;
245+
246+
match fut.as_mut().poll(&mut ctx) {
247+
Poll::Ready(fut_result) => Poll::Ready(fut_result),
248+
Poll::Pending => {
249+
get_ssl_mut(ssl_handle).set_ex_data(index, fut);
250+
251+
Poll::Pending
252+
}
253+
}
254+
}
255+
}
256+
}
257+
258+
mod private {
259+
pub trait Sealed {}
260+
}
261+
262+
impl private::Sealed for SslContextBuilder {}

tokio-boring/src/bridge.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
//! Bridge between sync IO traits and async tokio IO traits.
2-
32
use std::fmt;
43
use std::io;
54
use std::pin::Pin;
@@ -35,7 +34,7 @@ impl<S> AsyncStreamBridge<S> {
3534
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
3635
{
3736
let mut ctx =
38-
Context::from_waker(self.waker.as_ref().expect("missing task context pointer"));
37+
Context::from_waker(self.waker.as_ref().expect("BUG: missing waker in bridge"));
3938

4039
f(&mut ctx, Pin::new(&mut self.stream))
4140
}

tokio-boring/src/lib.rs

+16
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,14 @@ use std::pin::Pin;
2727
use std::task::{Context, Poll};
2828
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
2929

30+
mod async_callbacks;
3031
mod bridge;
3132

33+
use self::async_callbacks::TASK_WAKER_INDEX;
34+
pub use self::async_callbacks::{
35+
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError,
36+
BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, SslContextBuilderExt,
37+
};
3238
use self::bridge::AsyncStreamBridge;
3339

3440
/// Asynchronously performs a client-side TLS handshake over the provided stream.
@@ -90,6 +96,11 @@ impl<S> SslStream<S> {
9096
self.0.ssl()
9197
}
9298

99+
/// Returns a mutable reference to the `Ssl` object associated with this stream.
100+
pub fn ssl_mut(&mut self) -> &mut SslRef {
101+
self.0.ssl_mut()
102+
}
103+
93104
/// Returns a shared reference to the underlying stream.
94105
pub fn get_ref(&self) -> &S {
95106
&self.0.get_ref().stream
@@ -285,15 +296,20 @@ where
285296
let mut mid_handshake = self.0.take().expect("future polled after completion");
286297

287298
mid_handshake.get_mut().set_waker(Some(ctx));
299+
mid_handshake
300+
.ssl_mut()
301+
.set_ex_data(*TASK_WAKER_INDEX, Some(ctx.waker().clone()));
288302

289303
match mid_handshake.handshake() {
290304
Ok(mut stream) => {
291305
stream.get_mut().set_waker(None);
306+
stream.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None);
292307

293308
Poll::Ready(Ok(SslStream(stream)))
294309
}
295310
Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
296311
mid_handshake.get_mut().set_waker(None);
312+
mid_handshake.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None);
297313

298314
self.0 = Some(mid_handshake);
299315

0 commit comments

Comments
 (0)