Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 750442c

Browse files
committedAug 24, 2023
Introduce async callbacks
We introduce tokio_boring::SslContextBuilderExt, with 2 methods: * set_async_select_certificate_callback * set_async_private_key_method
1 parent cb27511 commit 750442c

File tree

7 files changed

+576
-2
lines changed

7 files changed

+576
-2
lines changed
 

‎boring/src/ssl/mod.rs

+13
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,9 @@ pub struct SelectCertError(ffi::ssl_select_cert_result_t);
478478
impl SelectCertError {
479479
/// A fatal error occured and the handshake should be terminated.
480480
pub const ERROR: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_error);
481+
482+
/// The operation could not be completed and should be retried later.
483+
pub const RETRY: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_retry);
481484
}
482485

483486
/// Extension types, to be used with `ClientHello::get_extension`.
@@ -3197,6 +3200,11 @@ impl<S> MidHandshakeSslStream<S> {
31973200
self.stream.ssl()
31983201
}
31993202

3203+
/// Returns a mutable reference to the `Ssl` of the stream.
3204+
pub fn ssl_mut(&mut self) -> &mut SslRef {
3205+
self.stream.ssl_mut()
3206+
}
3207+
32003208
/// Returns the underlying error which interrupted this handshake.
32013209
pub fn error(&self) -> &Error {
32023210
&self.error
@@ -3451,6 +3459,11 @@ impl<S> SslStream<S> {
34513459
pub fn ssl(&self) -> &SslRef {
34523460
&self.ssl
34533461
}
3462+
3463+
/// Returns a mutable reference to the `Ssl` object associated with this stream.
3464+
pub fn ssl_mut(&mut self) -> &mut SslRef {
3465+
&mut self.ssl
3466+
}
34543467
}
34553468

34563469
impl<S: Read + Write> Read for SslStream<S> {

‎tokio-boring/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pq-experimental = ["boring/pq-experimental"]
3131
[dependencies]
3232
boring = { workspace = true }
3333
boring-sys = { workspace = true }
34+
once_cell = { workspace = true }
3435
tokio = { workspace = true }
3536

3637
[dev-dependencies]

‎tokio-boring/src/async_callbacks.rs

+262
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
use boring::ex_data::Index;
2+
use boring::ssl::{self, ClientHello, PrivateKeyMethod, Ssl, SslContextBuilder};
3+
use once_cell::sync::Lazy;
4+
use std::future::Future;
5+
use std::pin::Pin;
6+
use std::task::{ready, Context, Poll, Waker};
7+
8+
type BoxSelectCertFuture = ExDataFuture<Result<BoxSelectCertFinish, AsyncSelectCertError>>;
9+
10+
type BoxSelectCertFinish = Box<dyn FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError>>;
11+
12+
/// The type of futures returned by [`AsyncPrivateKeyMethod`] methods.
13+
pub type BoxPrivateKeyMethodFuture =
14+
ExDataFuture<Result<BoxPrivateKeyMethodFinish, AsyncPrivateKeyMethodError>>;
15+
16+
/// The type of callbacks returned by [`BoxPrivateKeyMethodFuture`].
17+
pub type BoxPrivateKeyMethodFinish =
18+
Box<dyn FnOnce(&mut ssl::SslRef, &mut [u8]) -> Result<usize, AsyncPrivateKeyMethodError>>;
19+
20+
type ExDataFuture<T> = Pin<Box<dyn Future<Output = T> + Send + Sync>>;
21+
22+
pub(crate) static TASK_WAKER_INDEX: Lazy<Index<Ssl, Option<Waker>>> =
23+
Lazy::new(|| Ssl::new_ex_index().unwrap());
24+
pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy<Index<Ssl, BoxSelectCertFuture>> =
25+
Lazy::new(|| Ssl::new_ex_index().unwrap());
26+
pub(crate) static SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX: Lazy<
27+
Index<Ssl, BoxPrivateKeyMethodFuture>,
28+
> = Lazy::new(|| Ssl::new_ex_index().unwrap());
29+
30+
/// Extensions to [`SslContextBuilder`].
31+
///
32+
/// This trait provides additional methods to use async callbacks with boring.
33+
pub trait SslContextBuilderExt: private::Sealed {
34+
/// Sets a callback that is called before most [`ClientHello`] processing
35+
/// and before the decision whether to resume a session is made. The
36+
/// callback may inspect the [`ClientHello`] and configure the connection.
37+
///
38+
/// This method uses a function that returns a future whose output is
39+
/// itself a closure that will be passed [`ClientHello`] to configure
40+
/// the connection based on the computations done in the future.
41+
///
42+
/// See [`SslContextBuilder::set_select_certificate_callback`] for the sync
43+
/// setter of this callback.
44+
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
45+
where
46+
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
47+
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
48+
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static;
49+
50+
/// Configures a custom private key method on the context.
51+
///
52+
/// See [`AsyncPrivateKeyMethod`] for more details.
53+
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod);
54+
}
55+
56+
impl SslContextBuilderExt for SslContextBuilder {
57+
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
58+
where
59+
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
60+
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
61+
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static,
62+
{
63+
self.set_select_certificate_callback(move |mut client_hello| {
64+
let fut_poll_result = with_ex_data_future(
65+
&mut client_hello,
66+
*SELECT_CERT_FUTURE_INDEX,
67+
ClientHello::ssl_mut,
68+
|client_hello| {
69+
let fut = callback(client_hello)?;
70+
71+
Ok(Box::pin(async move {
72+
Ok(Box::new(fut.await?) as BoxSelectCertFinish)
73+
}))
74+
},
75+
);
76+
77+
let fut_result = match fut_poll_result {
78+
Poll::Ready(fut_result) => fut_result,
79+
Poll::Pending => return Err(ssl::SelectCertError::RETRY),
80+
};
81+
82+
let finish = fut_result.or(Err(ssl::SelectCertError::ERROR))?;
83+
84+
finish(client_hello).or(Err(ssl::SelectCertError::ERROR))
85+
})
86+
}
87+
88+
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) {
89+
self.set_private_key_method(AsyncPrivateKeyMethodBridge(Box::new(method)));
90+
}
91+
}
92+
93+
/// A fatal error to be returned from async select certificate callbacks.
94+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
95+
pub struct AsyncSelectCertError;
96+
97+
/// Describes async private key hooks. This is used to off-load signing
98+
/// operations to a custom, potentially asynchronous, backend. Metadata about the
99+
/// key such as the type and size are parsed out of the certificate.
100+
///
101+
/// See [`PrivateKeyMethod`] for the sync version of those hooks.
102+
///
103+
/// [`ssl_private_key_method_st`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_private_key_method_st
104+
pub trait AsyncPrivateKeyMethod: Send + Sync + 'static {
105+
/// Signs the message `input` using the specified signature algorithm.
106+
///
107+
/// This method uses a function that returns a future whose output is
108+
/// itself a closure that will be passed `ssl` and `output`
109+
/// to finish writing the signature.
110+
///
111+
/// See [`PrivateKeyMethod::sign`] for the sync version of this method.
112+
fn sign(
113+
&self,
114+
ssl: &mut ssl::SslRef,
115+
input: &[u8],
116+
signature_algorithm: ssl::SslSignatureAlgorithm,
117+
output: &mut [u8],
118+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>;
119+
120+
/// Decrypts `input`.
121+
///
122+
/// This method uses a function that returns a future whose output is
123+
/// itself a closure that will be passed `ssl` and `output`
124+
/// to finish decrypting the input.
125+
///
126+
/// See [`PrivateKeyMethod::decrypt`] for the sync version of this method.
127+
fn decrypt(
128+
&self,
129+
ssl: &mut ssl::SslRef,
130+
input: &[u8],
131+
output: &mut [u8],
132+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>;
133+
}
134+
135+
/// A fatal error to be returned from async private key methods.
136+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
137+
pub struct AsyncPrivateKeyMethodError;
138+
139+
struct AsyncPrivateKeyMethodBridge(Box<dyn AsyncPrivateKeyMethod>);
140+
141+
impl PrivateKeyMethod for AsyncPrivateKeyMethodBridge {
142+
fn sign(
143+
&self,
144+
ssl: &mut ssl::SslRef,
145+
input: &[u8],
146+
signature_algorithm: ssl::SslSignatureAlgorithm,
147+
output: &mut [u8],
148+
) -> Result<usize, ssl::PrivateKeyMethodError> {
149+
with_private_key_method(ssl, output, |ssl, output| {
150+
<dyn AsyncPrivateKeyMethod>::sign(&*self.0, ssl, input, signature_algorithm, output)
151+
})
152+
}
153+
154+
fn decrypt(
155+
&self,
156+
ssl: &mut ssl::SslRef,
157+
input: &[u8],
158+
output: &mut [u8],
159+
) -> Result<usize, ssl::PrivateKeyMethodError> {
160+
with_private_key_method(ssl, output, |ssl, output| {
161+
<dyn AsyncPrivateKeyMethod>::decrypt(&*self.0, ssl, input, output)
162+
})
163+
}
164+
165+
fn complete(
166+
&self,
167+
ssl: &mut ssl::SslRef,
168+
output: &mut [u8],
169+
) -> Result<usize, ssl::PrivateKeyMethodError> {
170+
with_private_key_method(ssl, output, |_, _| {
171+
// This should never be reached, if it does, that's a bug on boring's side,
172+
// which called `complete` without having been returned to with a pending
173+
// future from `sign` or `decrypt`.
174+
175+
if cfg!(debug_assertions) {
176+
panic!("BUG: boring called complete without a pending operation");
177+
}
178+
179+
Err(AsyncPrivateKeyMethodError)
180+
})
181+
}
182+
}
183+
184+
/// Creates and drives a private key method future.
185+
///
186+
/// This is a convenience function for the three methods of impl `PrivateKeyMethod``
187+
/// for `dyn AsyncPrivateKeyMethod`. It relies on [`with_ex_data_future`] to
188+
/// drive the future and then immediately calls the final [`BoxPrivateKeyMethodFinish`]
189+
/// when the future is ready.
190+
fn with_private_key_method(
191+
ssl: &mut ssl::SslRef,
192+
output: &mut [u8],
193+
create_fut: impl FnOnce(
194+
&mut ssl::SslRef,
195+
&mut [u8],
196+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>,
197+
) -> Result<usize, ssl::PrivateKeyMethodError> {
198+
let fut_poll_result = with_ex_data_future(
199+
ssl,
200+
*SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX,
201+
|ssl| ssl,
202+
|ssl| create_fut(ssl, output),
203+
);
204+
205+
let fut_result = match fut_poll_result {
206+
Poll::Ready(fut_result) => fut_result,
207+
Poll::Pending => return Err(ssl::PrivateKeyMethodError::RETRY),
208+
};
209+
210+
let finish = fut_result.or(Err(ssl::PrivateKeyMethodError::FAILURE))?;
211+
212+
finish(ssl, output).or(Err(ssl::PrivateKeyMethodError::FAILURE))
213+
}
214+
215+
/// Creates and drives a future stored in `ssl_handle`'s `Ssl` at ex data index `index`.
216+
///
217+
/// This function won't even bother storing the future in `index` if the future
218+
/// created by `create_fut` returns `Poll::Ready(_)` on the first poll call.
219+
fn with_ex_data_future<H, T, E>(
220+
ssl_handle: &mut H,
221+
index: Index<ssl::Ssl, ExDataFuture<Result<T, E>>>,
222+
get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef,
223+
create_fut: impl FnOnce(&mut H) -> Result<ExDataFuture<Result<T, E>>, E>,
224+
) -> Poll<Result<T, E>> {
225+
let ssl = get_ssl_mut(ssl_handle);
226+
let waker = ssl
227+
.ex_data(*TASK_WAKER_INDEX)
228+
.cloned()
229+
.flatten()
230+
.expect("task waker should be set");
231+
232+
let mut ctx = Context::from_waker(&waker);
233+
234+
match ssl.ex_data_mut(index) {
235+
Some(fut) => {
236+
let fut_result = ready!(fut.as_mut().poll(&mut ctx));
237+
238+
// NOTE(nox): For memory usage concerns, maybe we should implement
239+
// a way to remove the stored future from the `Ssl` value here?
240+
241+
Poll::Ready(fut_result)
242+
}
243+
None => {
244+
let mut fut = create_fut(ssl_handle)?;
245+
246+
match fut.as_mut().poll(&mut ctx) {
247+
Poll::Ready(fut_result) => Poll::Ready(fut_result),
248+
Poll::Pending => {
249+
get_ssl_mut(ssl_handle).set_ex_data(index, fut);
250+
251+
Poll::Pending
252+
}
253+
}
254+
}
255+
}
256+
}
257+
258+
mod private {
259+
pub trait Sealed {}
260+
}
261+
262+
impl private::Sealed for SslContextBuilder {}

‎tokio-boring/src/bridge.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
//! Bridge between sync IO traits and async tokio IO traits.
2-
32
use std::fmt;
43
use std::io;
54
use std::pin::Pin;
@@ -35,7 +34,7 @@ impl<S> AsyncStreamBridge<S> {
3534
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
3635
{
3736
let mut ctx =
38-
Context::from_waker(self.waker.as_ref().expect("missing task context pointer"));
37+
Context::from_waker(self.waker.as_ref().expect("BUG: missing waker in bridge"));
3938

4039
f(&mut ctx, Pin::new(&mut self.stream))
4140
}

‎tokio-boring/src/lib.rs

+16
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,14 @@ use std::pin::Pin;
2727
use std::task::{Context, Poll};
2828
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
2929

30+
mod async_callbacks;
3031
mod bridge;
3132

33+
use self::async_callbacks::TASK_WAKER_INDEX;
34+
pub use self::async_callbacks::{
35+
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError,
36+
BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, SslContextBuilderExt,
37+
};
3238
use self::bridge::AsyncStreamBridge;
3339

3440
/// Asynchronously performs a client-side TLS handshake over the provided stream.
@@ -90,6 +96,11 @@ impl<S> SslStream<S> {
9096
self.0.ssl()
9197
}
9298

99+
/// Returns a mutable reference to the `Ssl` object associated with this stream.
100+
pub fn ssl_mut(&mut self) -> &mut SslRef {
101+
self.0.ssl_mut()
102+
}
103+
93104
/// Returns a shared reference to the underlying stream.
94105
pub fn get_ref(&self) -> &S {
95106
&self.0.get_ref().stream
@@ -285,15 +296,20 @@ where
285296
let mut mid_handshake = self.0.take().expect("future polled after completion");
286297

287298
mid_handshake.get_mut().set_waker(Some(ctx));
299+
mid_handshake
300+
.ssl_mut()
301+
.set_ex_data(*TASK_WAKER_INDEX, Some(ctx.waker().clone()));
288302

289303
match mid_handshake.handshake() {
290304
Ok(mut stream) => {
291305
stream.get_mut().set_waker(None);
306+
stream.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None);
292307

293308
Poll::Ready(Ok(SslStream(stream)))
294309
}
295310
Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
296311
mid_handshake.get_mut().set_waker(None);
312+
mid_handshake.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None);
297313

298314
self.0 = Some(mid_handshake);
299315

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
use boring::hash::MessageDigest;
2+
use boring::pkey::PKey;
3+
use boring::rsa::Padding;
4+
use boring::sign::{RsaPssSaltlen, Signer};
5+
use boring::ssl::{SslRef, SslSignatureAlgorithm};
6+
use futures::future;
7+
use tokio::task::yield_now;
8+
use tokio_boring::{
9+
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, BoxPrivateKeyMethodFuture,
10+
SslContextBuilderExt,
11+
};
12+
13+
mod common;
14+
15+
use self::common::{connect, create_server, with_trivial_client_server_exchange};
16+
17+
#[allow(clippy::type_complexity)]
18+
struct Method {
19+
sign: Box<
20+
dyn Fn(
21+
&mut SslRef,
22+
&[u8],
23+
SslSignatureAlgorithm,
24+
&mut [u8],
25+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>
26+
+ Send
27+
+ Sync
28+
+ 'static,
29+
>,
30+
decrypt: Box<
31+
dyn Fn(
32+
&mut SslRef,
33+
&[u8],
34+
&mut [u8],
35+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>
36+
+ Send
37+
+ Sync
38+
+ 'static,
39+
>,
40+
}
41+
42+
impl Method {
43+
fn new() -> Self {
44+
Self {
45+
sign: Box::new(|_, _, _, _| unreachable!("called sign")),
46+
decrypt: Box::new(|_, _, _| unreachable!("called decrypt")),
47+
}
48+
}
49+
50+
fn sign(
51+
mut self,
52+
sign: impl Fn(
53+
&mut SslRef,
54+
&[u8],
55+
SslSignatureAlgorithm,
56+
&mut [u8],
57+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>
58+
+ Send
59+
+ Sync
60+
+ 'static,
61+
) -> Self {
62+
self.sign = Box::new(sign);
63+
64+
self
65+
}
66+
67+
#[allow(dead_code)]
68+
fn decrypt(
69+
mut self,
70+
decrypt: impl Fn(
71+
&mut SslRef,
72+
&[u8],
73+
&mut [u8],
74+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>
75+
+ Send
76+
+ Sync
77+
+ 'static,
78+
) -> Self {
79+
self.decrypt = Box::new(decrypt);
80+
81+
self
82+
}
83+
}
84+
85+
impl AsyncPrivateKeyMethod for Method {
86+
fn sign(
87+
&self,
88+
ssl: &mut SslRef,
89+
input: &[u8],
90+
signature_algorithm: SslSignatureAlgorithm,
91+
output: &mut [u8],
92+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError> {
93+
(self.sign)(ssl, input, signature_algorithm, output)
94+
}
95+
96+
fn decrypt(
97+
&self,
98+
ssl: &mut SslRef,
99+
input: &[u8],
100+
output: &mut [u8],
101+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError> {
102+
(self.decrypt)(ssl, input, output)
103+
}
104+
}
105+
106+
#[tokio::test]
107+
async fn test_sign_failure() {
108+
with_async_private_key_method_error(
109+
Method::new().sign(|_, _, _, _| Err(AsyncPrivateKeyMethodError)),
110+
)
111+
.await;
112+
}
113+
114+
#[tokio::test]
115+
async fn test_sign_future_failure() {
116+
with_async_private_key_method_error(
117+
Method::new().sign(|_, _, _, _| Ok(Box::pin(async { Err(AsyncPrivateKeyMethodError) }))),
118+
)
119+
.await;
120+
}
121+
122+
#[tokio::test]
123+
async fn test_sign_future_yield_failure() {
124+
with_async_private_key_method_error(Method::new().sign(|_, _, _, _| {
125+
Ok(Box::pin(async {
126+
yield_now().await;
127+
128+
Err(AsyncPrivateKeyMethodError)
129+
}))
130+
}))
131+
.await;
132+
}
133+
134+
#[tokio::test]
135+
async fn test_sign_ok() {
136+
with_trivial_client_server_exchange(|builder| {
137+
builder.set_async_private_key_method(Method::new().sign(
138+
|_, input, signature_algorithm, _| {
139+
assert_eq!(
140+
signature_algorithm,
141+
SslSignatureAlgorithm::RSA_PSS_RSAE_SHA256,
142+
);
143+
144+
let input = input.to_owned();
145+
146+
Ok(Box::pin(async move {
147+
Ok(Box::new(move |_: &mut SslRef, output: &mut [u8]| {
148+
Ok(sign_with_default_config(&input, output))
149+
}) as Box<_>)
150+
}))
151+
},
152+
));
153+
})
154+
.await;
155+
}
156+
157+
fn sign_with_default_config(input: &[u8], output: &mut [u8]) -> usize {
158+
let pkey = PKey::private_key_from_pem(include_bytes!("key.pem")).unwrap();
159+
let mut signer = Signer::new(MessageDigest::sha256(), &pkey).unwrap();
160+
161+
signer.set_rsa_padding(Padding::PKCS1_PSS).unwrap();
162+
signer
163+
.set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH)
164+
.unwrap();
165+
166+
signer.update(input).unwrap();
167+
168+
signer.sign(output).unwrap()
169+
}
170+
171+
async fn with_async_private_key_method_error(method: Method) {
172+
let (stream, addr) = create_server(move |builder| {
173+
builder.set_async_private_key_method(method);
174+
});
175+
176+
let server = async {
177+
let _err = stream.await.unwrap_err();
178+
};
179+
180+
let client = async {
181+
let _err = connect(addr, |builder| builder.set_ca_file("tests/cert.pem"))
182+
.await
183+
.unwrap_err();
184+
};
185+
186+
future::join(server, client).await;
187+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
use boring::ssl::ClientHello;
2+
use futures::future::{self, Pending};
3+
use futures::Future;
4+
use tokio::task::yield_now;
5+
use tokio_boring::{AsyncSelectCertError, SslContextBuilderExt};
6+
7+
mod common;
8+
9+
use self::common::{connect, create_server, with_trivial_client_server_exchange};
10+
11+
#[tokio::test]
12+
async fn test_async_select_certificate_callback_trivial() {
13+
with_trivial_client_server_exchange(|builder| {
14+
builder.set_async_select_certificate_callback(|_| {
15+
Ok(async move { Ok(|_: ClientHello<'_>| Ok(())) })
16+
});
17+
})
18+
.await;
19+
}
20+
21+
#[tokio::test]
22+
async fn test_async_select_certificate_callback_yield() {
23+
with_trivial_client_server_exchange(|builder| {
24+
builder.set_async_select_certificate_callback(|_| {
25+
Ok(async move {
26+
yield_now().await;
27+
28+
Ok(|_: ClientHello<'_>| Ok(()))
29+
})
30+
});
31+
})
32+
.await;
33+
}
34+
35+
#[tokio::test]
36+
async fn test_async_select_certificate_callback_return_error() {
37+
with_async_select_certificate_callback_error::<_, Pending<_>, fn(_: ClientHello<'_>) -> _>(
38+
|_| Err(AsyncSelectCertError),
39+
)
40+
.await;
41+
}
42+
43+
#[tokio::test]
44+
async fn test_async_select_certificate_callback_future_error() {
45+
with_async_select_certificate_callback_error::<_, _, fn(_: ClientHello<'_>) -> _>(|_| {
46+
Ok(async move { Err(AsyncSelectCertError) })
47+
})
48+
.await;
49+
}
50+
51+
#[tokio::test]
52+
async fn test_async_select_certificate_callback_future_yield_error() {
53+
with_async_select_certificate_callback_error::<_, _, fn(_: ClientHello<'_>) -> _>(|_| {
54+
Ok(async move {
55+
yield_now().await;
56+
57+
Err(AsyncSelectCertError)
58+
})
59+
})
60+
.await;
61+
}
62+
63+
#[tokio::test]
64+
async fn test_async_select_certificate_callback_finish_error() {
65+
with_async_select_certificate_callback_error(|_| {
66+
Ok(async move {
67+
yield_now().await;
68+
69+
Ok(|_: ClientHello<'_>| Err(AsyncSelectCertError))
70+
})
71+
})
72+
.await;
73+
}
74+
75+
async fn with_async_select_certificate_callback_error<Init, Fut, Finish>(callback: Init)
76+
where
77+
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
78+
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
79+
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static,
80+
{
81+
let (stream, addr) = create_server(|builder| {
82+
builder.set_async_select_certificate_callback(callback);
83+
});
84+
85+
let server = async {
86+
let _err = stream.await.unwrap_err();
87+
};
88+
89+
let client = async {
90+
let _err = connect(addr, |builder| builder.set_ca_file("tests/cert.pem"))
91+
.await
92+
.unwrap_err();
93+
};
94+
95+
future::join(server, client).await;
96+
}

0 commit comments

Comments
 (0)
Please sign in to comment.