From 6e2b30cf30d09a2666aa7b901eb771a9b1ea679e Mon Sep 17 00:00:00 2001 From: Alessandro Ghedini Date: Fri, 23 Jun 2023 10:44:20 +0100 Subject: [PATCH 01/12] Introduce ssl::Error::would_block --- boring/src/ssl/error.rs | 4 ++++ boring/src/ssl/mod.rs | 50 +++++++++++++++++------------------------ 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/boring/src/ssl/error.rs b/boring/src/ssl/error.rs index 5fb916598..a7e6f9fcb 100644 --- a/boring/src/ssl/error.rs +++ b/boring/src/ssl/error.rs @@ -81,6 +81,10 @@ impl Error { _ => None, } } + + pub fn would_block(&self) -> bool { + matches!(self.code, ErrorCode::WANT_READ | ErrorCode::WANT_WRITE) + } } impl From for Error { diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 00aa82a04..94a7f6d53 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -3242,11 +3242,9 @@ impl MidHandshakeSslStream { Ok(self.stream) } else { self.error = self.stream.make_error(ret); - match self.error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(self)) - } - _ => Err(HandshakeError::Failure(self)), + match self.error.would_block() { + true => Err(HandshakeError::WouldBlock(self)), + false => Err(HandshakeError::Failure(self)), } } } @@ -3606,14 +3604,12 @@ where Ok(stream) } else { let error = stream.make_error(ret); - match error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })) - } - _ => Err(HandshakeError::Failure(MidHandshakeSslStream { + match error.would_block() { + true => Err(HandshakeError::WouldBlock(MidHandshakeSslStream { + stream, + error, + })), + false => Err(HandshakeError::Failure(MidHandshakeSslStream { stream, error, })), @@ -3633,14 +3629,12 @@ where Ok(stream) } else { let error = stream.make_error(ret); - match error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })) - } - _ => Err(HandshakeError::Failure(MidHandshakeSslStream { + match error.would_block() { + true => Err(HandshakeError::WouldBlock(MidHandshakeSslStream { + stream, + error, + })), + false => Err(HandshakeError::Failure(MidHandshakeSslStream { stream, error, })), @@ -3662,14 +3656,12 @@ where Ok(stream) } else { let error = stream.make_error(ret); - match error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })) - } - _ => Err(HandshakeError::Failure(MidHandshakeSslStream { + match error.would_block() { + true => Err(HandshakeError::WouldBlock(MidHandshakeSslStream { + stream, + error, + })), + false => Err(HandshakeError::Failure(MidHandshakeSslStream { stream, error, })), From 6de6f1bed30717aef682602fdaf221cc9560e24c Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Thu, 14 Sep 2023 20:00:10 +0200 Subject: [PATCH 02/12] Panic on error when setting default curves list These lists are hardcoded and the calls have no business failing in the first place. --- boring/src/ssl/mod.rs | 10 ++++++---- boring/src/ssl/test/mod.rs | 8 ++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 94a7f6d53..3f8bb3dc3 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -2437,7 +2437,7 @@ impl SslRef { } #[cfg(feature = "kx-safe-default")] - fn client_set_default_curves_list(&mut self) -> Result<(), ErrorStack> { + fn client_set_default_curves_list(&mut self) { let curves = if cfg!(feature = "kx-client-pq-preferred") { if cfg!(feature = "kx-client-nist-required") { "P256Kyber768Draft00:P-256:P-384:P-521" @@ -2459,11 +2459,13 @@ impl SslRef { }; self.set_curves_list(curves) + .expect("invalid default client curves list"); } #[cfg(feature = "kx-safe-default")] - fn server_set_default_curves_list(&mut self) -> Result<(), ErrorStack> { + fn server_set_default_curves_list(&mut self) { self.set_curves_list("X25519Kyber768Draft00:P256Kyber768Draft00:X25519:P-256:P-384") + .expect("invalid default server curves list"); } /// Like [`SslContextBuilder::set_verify`]. @@ -3597,7 +3599,7 @@ where let mut stream = self.inner; #[cfg(feature = "kx-safe-default")] - stream.ssl.client_set_default_curves_list()?; + stream.ssl.client_set_default_curves_list(); let ret = unsafe { ffi::SSL_connect(stream.ssl.as_ptr()) }; if ret > 0 { @@ -3622,7 +3624,7 @@ where let mut stream = self.inner; #[cfg(feature = "kx-safe-default")] - stream.ssl.server_set_default_curves_list()?; + stream.ssl.server_set_default_curves_list(); let ret = unsafe { ffi::SSL_accept(stream.ssl.as_ptr()) }; if ret > 0 { diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index 367b8f4d0..5c986199a 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -1122,8 +1122,8 @@ fn client_set_default_curves_list() { let ssl_ctx = SslContextBuilder::new(SslMethod::tls()).unwrap().build(); let mut ssl = Ssl::new(&ssl_ctx).unwrap(); - ssl.client_set_default_curves_list() - .expect("Failed to set curves list. Is Kyber768 missing in boringSSL?") + // Panics if Kyber768 missing in boringSSL. + ssl.client_set_default_curves_list(); } #[cfg(feature = "kx-safe-default")] @@ -1132,6 +1132,6 @@ fn server_set_default_curves_list() { let ssl_ctx = SslContextBuilder::new(SslMethod::tls()).unwrap().build(); let mut ssl = Ssl::new(&ssl_ctx).unwrap(); - ssl.server_set_default_curves_list() - .expect("Failed to set curves list. Is Kyber768 missing in boringSSL?") + // Panics if Kyber768 missing in boringSSL. + ssl.server_set_default_curves_list(); } From 1b941c6256b4141c2fe53a0286e8846e76ce12d7 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Thu, 3 Aug 2023 18:35:55 +0200 Subject: [PATCH 03/12] Introduce setup_accept and setup_connect These two new kinds of methods immediately return a MidHandshakeSslStream instead of actually initiating a handshake. This greatly simplifies loops around MidHandshakeSslStream::WouldBlock. --- boring/src/ssl/connector.rs | 71 ++++++++++++++-- boring/src/ssl/mod.rs | 157 +++++++++++++++++++++++------------- tokio-boring/src/lib.rs | 83 +++++-------------- 3 files changed, 186 insertions(+), 125 deletions(-) diff --git a/boring/src/ssl/connector.rs b/boring/src/ssl/connector.rs index 6bb58dabb..e910a324b 100644 --- a/boring/src/ssl/connector.rs +++ b/boring/src/ssl/connector.rs @@ -10,6 +10,8 @@ use crate::ssl::{ use crate::version; use std::net::IpAddr; +use super::MidHandshakeSslStream; + const FFDHE_2048: &str = " -----BEGIN DH PARAMETERS----- MIIBCAKCAQEA//////////+t+FRYortKmq/cViAnPTzx2LnFg84tNpWp4TZBFGQz @@ -99,11 +101,30 @@ impl SslConnector { /// Initiates a client-side TLS session on a stream. /// /// The domain is used for SNI and hostname verification. + pub fn setup_connect( + &self, + domain: &str, + stream: S, + ) -> Result, ErrorStack> + where + S: Read + Write, + { + self.configure()?.setup_connect(domain, stream) + } + + /// Attempts a client-side TLS session on a stream. + /// + /// The domain is used for SNI (if it is not an IP address) and hostname verification if enabled. + /// + /// This is a convenience method which combines [`Self::setup_connect`] and + /// [`MidHandshakeSslStream::handshake`]. pub fn connect(&self, domain: &str, stream: S) -> Result, HandshakeError> where S: Read + Write, { - self.configure()?.connect(domain, stream) + self.setup_connect(domain, stream) + .map_err(HandshakeError::SetupFailure)? + .handshake() } /// Returns a structure allowing for configuration of a single TLS session before connection. @@ -190,7 +211,7 @@ impl ConnectConfiguration { self.verify_hostname = verify_hostname; } - /// Returns an `Ssl` configured to connect to the provided domain. + /// Returns an [`Ssl`] configured to connect to the provided domain. /// /// The domain is used for SNI (if it is not an IP address) and hostname verification if enabled. pub fn into_ssl(mut self, domain: &str) -> Result { @@ -214,11 +235,33 @@ impl ConnectConfiguration { /// Initiates a client-side TLS session on a stream. /// /// The domain is used for SNI (if it is not an IP address) and hostname verification if enabled. + /// + /// This is a convenience method which combines [`Self::into_ssl`] and + /// [`Ssl::setup_connect`]. + pub fn setup_connect( + self, + domain: &str, + stream: S, + ) -> Result, ErrorStack> + where + S: Read + Write, + { + Ok(self.into_ssl(domain)?.setup_connect(stream)) + } + + /// Attempts a client-side TLS session on a stream. + /// + /// The domain is used for SNI (if it is not an IP address) and hostname verification if enabled. + /// + /// This is a convenience method which combines [`Self::setup_connect`] and + /// [`MidHandshakeSslStream::handshake`]. pub fn connect(self, domain: &str, stream: S) -> Result, HandshakeError> where S: Read + Write, { - self.into_ssl(domain)?.connect(stream) + self.setup_connect(domain, stream) + .map_err(HandshakeError::SetupFailure)? + .handshake() } } @@ -327,13 +370,29 @@ impl SslAcceptor { Ok(SslAcceptorBuilder(ctx)) } - /// Initiates a server-side TLS session on a stream. - pub fn accept(&self, stream: S) -> Result, HandshakeError> + /// Initiates a server-side TLS handshake on a stream. + /// + /// See [`Ssl::setup_accept`] for more details. + pub fn setup_accept(&self, stream: S) -> Result, ErrorStack> where S: Read + Write, { let ssl = Ssl::new(&self.0)?; - ssl.accept(stream) + + Ok(ssl.setup_accept(stream)) + } + + /// Attempts a server-side TLS handshake on a stream. + /// + /// This is a convenience method which combines [`Self::setup_accept`] and + /// [`MidHandshakeSslStream::handshake`]. + pub fn accept(&self, stream: S) -> Result, HandshakeError> + where + S: Read + Write, + { + self.setup_accept(stream) + .map_err(HandshakeError::SetupFailure)? + .handshake() } /// Consumes the `SslAcceptor`, returning the inner raw `SslContext`. diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 3f8bb3dc3..c9cd6f6ef 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -2320,10 +2320,11 @@ impl Ssl { } } - /// Creates a new `Ssl`. + /// Creates a new [`Ssl`]. /// - /// This corresponds to [`SSL_new`]. - /// This function does the same as [`Self:new()`] except that it takes &[SslContextRef]. + /// This corresponds to [`SSL_new`](`ffi::SSL_new`). + /// + /// This function does the same as [`Self:new`] except that it takes &[SslContextRef]. // Both functions exist for backward compatibility (no breaking API). pub fn new_from_ref(ctx: &SslContextRef) -> Result { unsafe { @@ -2337,34 +2338,52 @@ impl Ssl { } } - /// Initiates a client-side TLS handshake. + /// Initiates a client-side TLS handshake, returning a [`MidHandshakeSslStream`]. /// - /// This corresponds to [`SSL_connect`]. + /// This method is guaranteed to return without calling any callback defined + /// in the internal [`Ssl`] or [`SslContext`]. + /// + /// See [`SslStreamBuilder::setup_connect`] for more details. /// /// # Warning /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// `SslConnector` rather than `Ssl` directly, as it manages that configuration. + /// BoringSSL's default configuration is insecure. It is highly recommended to use + /// [`SslConnector`] rather than [`Ssl`] directly, as it manages that configuration. + pub fn setup_connect(self, stream: S) -> MidHandshakeSslStream + where + S: Read + Write, + { + SslStreamBuilder::new(self, stream).setup_connect() + } + + /// Attempts a client-side TLS handshake. + /// + /// This is a convenience method which combines [`Self::setup_connect`] and + /// [`MidHandshakeSslStream::handshake`]. + /// + /// # Warning /// - /// [`SSL_connect`]: https://www.openssl.org/docs/manmaster/man3/SSL_connect.html + /// OpenSSL's default configuration is insecure. It is highly recommended to use + /// [`SslConnector`] rather than `Ssl` directly, as it manages that configuration. pub fn connect(self, stream: S) -> Result, HandshakeError> where S: Read + Write, { - SslStreamBuilder::new(self, stream).connect() + self.setup_connect(stream).handshake() } /// Initiates a server-side TLS handshake. /// - /// This corresponds to [`SSL_accept`]. + /// This method is guaranteed to return without calling any callback defined + /// in the internal [`Ssl`] or [`SslContext`]. /// - /// # Warning + /// See [`SslStreamBuilder::setup_accept`] for more details. /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// `SslAcceptor` rather than `Ssl` directly, as it manages that configuration. + /// # Warning /// - /// [`SSL_accept`]: https://www.openssl.org/docs/manmaster/man3/SSL_accept.html - pub fn accept(self, stream: S) -> Result, HandshakeError> + /// BoringSSL's default configuration is insecure. It is highly recommended to use + /// [`SslAcceptor`] rather than [`Ssl`] directly, as it manages that configuration. + pub fn setup_accept(self, stream: S) -> MidHandshakeSslStream where S: Read + Write, { @@ -2383,7 +2402,25 @@ impl Ssl { } } - SslStreamBuilder::new(self, stream).accept() + SslStreamBuilder::new(self, stream).setup_accept() + } + + /// Attempts a server-side TLS handshake. + /// + /// This is a convenience method which combines [`Self::setup_accept`] and + /// [`MidHandshakeSslStream::handshake`]. + /// + /// # Warning + /// + /// OpenSSL's default configuration is insecure. It is highly recommended to use + /// `SslAcceptor` rather than `Ssl` directly, as it manages that configuration. + /// + /// [`SSL_accept`]: https://www.openssl.org/docs/manmaster/man3/SSL_accept.html + pub fn accept(self, stream: S) -> Result, HandshakeError> + where + S: Read + Write, + { + self.setup_accept(stream).handshake() } } @@ -3594,56 +3631,68 @@ where unsafe { ffi::SSL_set_accept_state(self.inner.ssl.as_ptr()) } } - /// See `Ssl::connect` - pub fn connect(self) -> Result, HandshakeError> { - let mut stream = self.inner; + /// Initiates a client-side TLS handshake, returning a [`MidHandshakeSslStream`]. + /// + /// This method calls [`Self::set_connect_state`] and returns without actually + /// initiating the handshake. The caller is then free to call + /// [`MidHandshakeSslStream`] and loop on [`HandshakeError::WouldBlock`]. + pub fn setup_connect(mut self) -> MidHandshakeSslStream { + self.set_connect_state(); #[cfg(feature = "kx-safe-default")] - stream.ssl.client_set_default_curves_list(); + self.inner.ssl.client_set_default_curves_list(); - let ret = unsafe { ffi::SSL_connect(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - match error.would_block() { - true => Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })), - false => Err(HandshakeError::Failure(MidHandshakeSslStream { - stream, - error, - })), - } + MidHandshakeSslStream { + stream: self.inner, + error: Error { + code: ErrorCode::WANT_WRITE, + cause: Some(InnerError::Io(io::Error::new( + io::ErrorKind::WouldBlock, + "connect handshake has not started yet", + ))), + }, } } - /// See `Ssl::accept` - pub fn accept(self) -> Result, HandshakeError> { - let mut stream = self.inner; + /// Attempts a client-side TLS handshake. + /// + /// This is a convenience method which combines [`Self::setup_connect`] and + /// [`MidHandshakeSslStream::handshake`]. + pub fn connect(self) -> Result, HandshakeError> { + self.setup_connect().handshake() + } + + /// Initiates a server-side TLS handshake, returning a [`MidHandshakeSslStream`]. + /// + /// This method calls [`Self::set_accept_state`] and returns without actually + /// initiating the handshake. The caller is then free to call + /// [`MidHandshakeSslStream`] and loop on [`HandshakeError::WouldBlock`]. + pub fn setup_accept(mut self) -> MidHandshakeSslStream { + self.set_accept_state(); #[cfg(feature = "kx-safe-default")] - stream.ssl.server_set_default_curves_list(); + self.inner.ssl.server_set_default_curves_list(); - let ret = unsafe { ffi::SSL_accept(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - match error.would_block() { - true => Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })), - false => Err(HandshakeError::Failure(MidHandshakeSslStream { - stream, - error, - })), - } + MidHandshakeSslStream { + stream: self.inner, + error: Error { + code: ErrorCode::WANT_READ, + cause: Some(InnerError::Io(io::Error::new( + io::ErrorKind::WouldBlock, + "accept handshake has not started yet", + ))), + }, } } + /// Attempts a server-side TLS handshake. + /// + /// This is a convenience method which combines [`Self::setup_accept`] and + /// [`MidHandshakeSslStream::handshake`]. + pub fn accept(self) -> Result, HandshakeError> { + self.setup_accept().handshake() + } + /// Initiates the handshake. /// /// This will fail if `set_accept_state` or `set_connect_state` was not called first. diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index a0dd58c52..f594231dd 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -13,6 +13,7 @@ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +use boring::error::ErrorStack; use boring::ssl::{ self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor, SslRef, @@ -35,7 +36,7 @@ pub async fn connect( where S: AsyncRead + AsyncWrite + Unpin, { - handshake(|s| config.connect(domain, s), stream).await + handshake(|s| config.setup_connect(domain, s), stream).await } /// Asynchronously performs a server-side TLS handshake over the provided stream. @@ -43,24 +44,22 @@ pub async fn accept(acceptor: &SslAcceptor, stream: S) -> Result where S: AsyncRead + AsyncWrite + Unpin, { - handshake(|s| acceptor.accept(s), stream).await + handshake(|s| acceptor.setup_accept(s), stream).await } -async fn handshake(f: F, stream: S) -> Result, HandshakeError> +async fn handshake( + f: impl FnOnce(StreamWrapper) -> Result>, ErrorStack>, + stream: S, +) -> Result, HandshakeError> where - F: FnOnce( - StreamWrapper, - ) - -> Result>, ssl::HandshakeError>> - + Unpin, S: AsyncRead + AsyncWrite + Unpin, { - let start = StartHandshakeFuture(Some(StartHandshakeFutureInner { f, stream })); + let ongoing_handshake = Some( + f(StreamWrapper { stream, context: 0 }) + .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?, + ); - match start.await? { - StartedHandshake::Done(s) => Ok(s), - StartedHandshake::Mid(s) => HandshakeFuture(Some(s)).await, - } + HandshakeFuture(ongoing_handshake).await } struct StreamWrapper { @@ -334,53 +333,6 @@ where } } -enum StartedHandshake { - Done(SslStream), - Mid(MidHandshakeSslStream>), -} - -struct StartHandshakeFuture(Option>); - -struct StartHandshakeFutureInner { - f: F, - stream: S, -} - -impl Future for StartHandshakeFuture -where - F: FnOnce( - StreamWrapper, - ) - -> Result>, ssl::HandshakeError>> - + Unpin, - S: Unpin, -{ - type Output = Result, HandshakeError>; - - fn poll( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - ) -> Poll, HandshakeError>> { - let inner = self.0.take().expect("future polled after completion"); - - let stream = StreamWrapper { - stream: inner.stream, - context: ctx as *mut _ as usize, - }; - match (inner.f)(stream) { - Ok(mut s) => { - s.get_mut().context = 0; - Poll::Ready(Ok(StartedHandshake::Done(SslStream(s)))) - } - Err(ssl::HandshakeError::WouldBlock(mut s)) => { - s.get_mut().context = 0; - Poll::Ready(Ok(StartedHandshake::Mid(s))) - } - Err(e) => Poll::Ready(Err(HandshakeError(e))), - } - } -} - struct HandshakeFuture(Option>>); impl Future for HandshakeFuture @@ -389,21 +341,22 @@ where { type Output = Result, HandshakeError>; - fn poll( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - ) -> Poll, HandshakeError>> { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut s = self.0.take().expect("future polled after completion"); - s.get_mut().context = ctx as *mut _ as usize; + s.get_mut().context = cx as *mut _ as usize; + match s.handshake() { Ok(mut s) => { s.get_mut().context = 0; + Poll::Ready(Ok(SslStream(s))) } Err(ssl::HandshakeError::WouldBlock(mut s)) => { s.get_mut().context = 0; + self.0 = Some(s); + Poll::Pending } Err(e) => Poll::Ready(Err(HandshakeError(e))), From 92495fa65b20fcca5c045919cc60de8a0699dc43 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Fri, 4 Aug 2023 13:42:09 +0200 Subject: [PATCH 04/12] Introduce helper module in tokio-boring tests --- tokio-boring/tests/client_server.rs | 88 ++++---------------------- tokio-boring/tests/common/mod.rs | 96 +++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 77 deletions(-) create mode 100644 tokio-boring/tests/common/mod.rs diff --git a/tokio-boring/tests/client_server.rs b/tokio-boring/tests/client_server.rs index 72c5a040c..925f9875e 100644 --- a/tokio-boring/tests/client_server.rs +++ b/tokio-boring/tests/client_server.rs @@ -1,11 +1,12 @@ -use boring::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod}; +use boring::ssl::{SslConnector, SslMethod}; use futures::future; -use std::future::Future; -use std::net::{SocketAddr, ToSocketAddrs}; -use std::pin::Pin; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream}; -use tokio_boring::{HandshakeError, SslStream}; +use std::net::ToSocketAddrs; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +mod common; + +use self::common::{connect, create_server, with_trivial_client_server_exchange}; #[tokio::test] async fn google() { @@ -33,75 +34,14 @@ async fn google() { assert!(response.ends_with("") || response.ends_with("")); } -fn create_server() -> ( - impl Future, HandshakeError>>, - SocketAddr, -) { - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - - listener.set_nonblocking(true).unwrap(); - - let listener = TcpListener::from_std(listener).unwrap(); - let addr = listener.local_addr().unwrap(); - - let server = async move { - let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - acceptor - .set_private_key_file("tests/key.pem", SslFiletype::PEM) - .unwrap(); - acceptor - .set_certificate_chain_file("tests/cert.pem") - .unwrap(); - let acceptor = acceptor.build(); - - let stream = listener.accept().await.unwrap().0; - - tokio_boring::accept(&acceptor, stream).await - }; - - (server, addr) -} - #[tokio::test] async fn server() { - let (stream, addr) = create_server(); - - let server = async { - let mut stream = stream.await.unwrap(); - let mut buf = [0; 4]; - stream.read_exact(&mut buf).await.unwrap(); - assert_eq!(&buf, b"asdf"); - - stream.write_all(b"jkl;").await.unwrap(); - - future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx)) - .await - .unwrap(); - }; - - let client = async { - let mut connector = SslConnector::builder(SslMethod::tls()).unwrap(); - connector.set_ca_file("tests/cert.pem").unwrap(); - let config = connector.build().configure().unwrap(); - - let stream = TcpStream::connect(&addr).await.unwrap(); - let mut stream = tokio_boring::connect(config, "localhost", stream) - .await - .unwrap(); - - stream.write_all(b"asdf").await.unwrap(); - - let mut buf = vec![]; - stream.read_to_end(&mut buf).await.unwrap(); - assert_eq!(buf, b"jkl;"); - }; - - future::join(server, client).await; + with_trivial_client_server_exchange(|_| ()).await; } #[tokio::test] async fn handshake_error() { - let (stream, addr) = create_server(); + let (stream, addr) = create_server(|_| ()); let server = async { let err = stream.await.unwrap_err(); @@ -110,13 +50,7 @@ async fn handshake_error() { }; let client = async { - let connector = SslConnector::builder(SslMethod::tls()).unwrap(); - let config = connector.build().configure().unwrap(); - let stream = TcpStream::connect(&addr).await.unwrap(); - - let err = tokio_boring::connect(config, "localhost", stream) - .await - .unwrap_err(); + let err = connect(addr, |_| Ok(())).await.unwrap_err(); assert!(err.into_source_stream().is_some()); }; diff --git a/tokio-boring/tests/common/mod.rs b/tokio-boring/tests/common/mod.rs new file mode 100644 index 000000000..6ed394efe --- /dev/null +++ b/tokio-boring/tests/common/mod.rs @@ -0,0 +1,96 @@ +#![allow(dead_code)] + +use boring::error::ErrorStack; +use boring::ssl::{ + SslAcceptor, SslAcceptorBuilder, SslConnector, SslConnectorBuilder, SslFiletype, SslMethod, +}; +use futures::future::{self, Future}; +use std::net::SocketAddr; +use std::pin::Pin; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_boring::{HandshakeError, SslStream}; + +pub(crate) fn create_server( + setup: impl FnOnce(&mut SslAcceptorBuilder), +) -> ( + impl Future, HandshakeError>>, + SocketAddr, +) { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + + listener.set_nonblocking(true).unwrap(); + + let listener = TcpListener::from_std(listener).unwrap(); + let addr = listener.local_addr().unwrap(); + + let server = async move { + let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + + acceptor + .set_private_key_file("tests/key.pem", SslFiletype::PEM) + .unwrap(); + + acceptor + .set_certificate_chain_file("tests/cert.pem") + .unwrap(); + + setup(&mut acceptor); + + let acceptor = acceptor.build(); + + let stream = listener.accept().await.unwrap().0; + + tokio_boring::accept(&acceptor, stream).await + }; + + (server, addr) +} + +pub(crate) async fn connect( + addr: SocketAddr, + setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>, +) -> Result, HandshakeError> { + let mut connector = SslConnector::builder(SslMethod::tls()).unwrap(); + + setup(&mut connector).unwrap(); + + let config = connector.build().configure().unwrap(); + + let stream = TcpStream::connect(&addr).await.unwrap(); + + tokio_boring::connect(config, "localhost", stream).await +} + +pub(crate) async fn with_trivial_client_server_exchange( + server_setup: impl FnOnce(&mut SslAcceptorBuilder), +) { + let (stream, addr) = create_server(server_setup); + + let server = async { + let mut stream = stream.await.unwrap(); + let mut buf = [0; 4]; + stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"asdf"); + + stream.write_all(b"jkl;").await.unwrap(); + + future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx)) + .await + .unwrap(); + }; + + let client = async { + let mut stream = connect(addr, |builder| builder.set_ca_file("tests/cert.pem")) + .await + .unwrap(); + + stream.write_all(b"asdf").await.unwrap(); + + let mut buf = vec![]; + stream.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"jkl;"); + }; + + future::join(server, client).await; +} From 5b6bedf7a8a3ac2bfebdbc1450ef22685959caae Mon Sep 17 00:00:00 2001 From: Alessandro Ghedini Date: Fri, 23 Jun 2023 10:44:20 +0100 Subject: [PATCH 05/12] Add a few WouldBlock cases --- boring/src/ssl/error.rs | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/boring/src/ssl/error.rs b/boring/src/ssl/error.rs index a7e6f9fcb..e77ce291b 100644 --- a/boring/src/ssl/error.rs +++ b/boring/src/ssl/error.rs @@ -27,6 +27,17 @@ impl ErrorCode { /// Wait for write readiness and retry the operation. pub const WANT_WRITE: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_WRITE); + pub const WANT_X509_LOOKUP: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_X509_LOOKUP); + + pub const PENDING_SESSION: ErrorCode = ErrorCode(ffi::SSL_ERROR_PENDING_SESSION); + + pub const PENDING_CERTIFICATE: ErrorCode = ErrorCode(ffi::SSL_ERROR_PENDING_CERTIFICATE); + + pub const WANT_PRIVATE_KEY_OPERATION: ErrorCode = + ErrorCode(ffi::SSL_ERROR_WANT_PRIVATE_KEY_OPERATION); + + pub const PENDING_TICKET: ErrorCode = ErrorCode(ffi::SSL_ERROR_PENDING_TICKET); + /// A non-recoverable IO error occurred. pub const SYSCALL: ErrorCode = ErrorCode(ffi::SSL_ERROR_SYSCALL); @@ -83,7 +94,16 @@ impl Error { } pub fn would_block(&self) -> bool { - matches!(self.code, ErrorCode::WANT_READ | ErrorCode::WANT_WRITE) + matches!( + self.code, + ErrorCode::WANT_READ + | ErrorCode::WANT_WRITE + | ErrorCode::WANT_X509_LOOKUP + | ErrorCode::PENDING_SESSION + | ErrorCode::PENDING_CERTIFICATE + | ErrorCode::WANT_PRIVATE_KEY_OPERATION + | ErrorCode::PENDING_TICKET + ) } } From b6eb9f1855cf0e3df82aa93bad034ef71250dbf9 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Fri, 4 Aug 2023 13:09:24 +0200 Subject: [PATCH 06/12] Introduce AsyncStreamBridge This encapsulates a bit better the unsafety of task context management to invoke async code from inside boring. --- tokio-boring/src/bridge.rs | 86 +++++++++++++++++++ tokio-boring/src/lib.rs | 165 +++++++++++++------------------------ 2 files changed, 143 insertions(+), 108 deletions(-) create mode 100644 tokio-boring/src/bridge.rs diff --git a/tokio-boring/src/bridge.rs b/tokio-boring/src/bridge.rs new file mode 100644 index 000000000..9de3fc224 --- /dev/null +++ b/tokio-boring/src/bridge.rs @@ -0,0 +1,86 @@ +//! Bridge between sync IO traits and async tokio IO traits. + +use std::fmt; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub(crate) struct AsyncStreamBridge { + pub(crate) stream: S, + waker: Option, +} + +impl AsyncStreamBridge { + pub(crate) fn new(stream: S) -> Self + where + S: AsyncRead + AsyncWrite + Unpin, + { + Self { + stream, + waker: None, + } + } + + pub(crate) fn set_waker(&mut self, ctx: Option<&mut Context<'_>>) { + self.waker = ctx.map(|ctx| ctx.waker().clone()) + } + + /// # Panics + /// + /// Panics if the bridge has no waker. + pub(crate) fn with_context(&mut self, f: F) -> R + where + S: Unpin, + F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, + { + let mut ctx = + Context::from_waker(self.waker.as_ref().expect("missing task context pointer")); + + f(&mut ctx, Pin::new(&mut self.stream)) + } +} + +impl io::Read for AsyncStreamBridge +where + S: AsyncRead + Unpin, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.with_context(|ctx, stream| { + let mut buf = ReadBuf::new(buf); + + match stream.poll_read(ctx, &mut buf)? { + Poll::Ready(()) => Ok(buf.filled().len()), + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + }) + } +} + +impl io::Write for AsyncStreamBridge +where + S: AsyncWrite + Unpin, +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) { + Poll::Ready(r) => r, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + } + + fn flush(&mut self) -> io::Result<()> { + match self.with_context(|ctx, stream| stream.poll_flush(ctx)) { + Poll::Ready(r) => r, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + } +} + +impl fmt::Debug for AsyncStreamBridge +where + S: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.stream, fmt) + } +} diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index f594231dd..f437ee262 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -27,6 +27,10 @@ use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +mod bridge; + +use self::bridge::AsyncStreamBridge; + /// Asynchronously performs a client-side TLS handshake over the provided stream. pub async fn connect( config: ConnectConfiguration, @@ -48,94 +52,18 @@ where } async fn handshake( - f: impl FnOnce(StreamWrapper) -> Result>, ErrorStack>, + f: impl FnOnce( + AsyncStreamBridge, + ) -> Result>, ErrorStack>, stream: S, ) -> Result, HandshakeError> where S: AsyncRead + AsyncWrite + Unpin, { - let ongoing_handshake = Some( - f(StreamWrapper { stream, context: 0 }) - .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?, - ); - - HandshakeFuture(ongoing_handshake).await -} - -struct StreamWrapper { - stream: S, - context: usize, -} - -impl StreamWrapper { - /// # Safety - /// - /// Must be called with `context` set to a valid pointer to a live `Context` object, and the - /// wrapper must be pinned in memory. - unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) { - debug_assert_ne!(self.context, 0); - let stream = Pin::new_unchecked(&mut self.stream); - let context = &mut *(self.context as *mut _); - (stream, context) - } -} + let mid_handshake = f(AsyncStreamBridge::new(stream)) + .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?; -impl fmt::Debug for StreamWrapper -where - S: fmt::Debug, -{ - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Debug::fmt(&self.stream, fmt) - } -} - -impl StreamWrapper -where - S: Unpin, -{ - fn with_context(&mut self, f: F) -> R - where - F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, - { - unsafe { - assert_ne!(self.context, 0); - let waker = &mut *(self.context as *mut _); - f(waker, Pin::new(&mut self.stream)) - } - } -} - -impl Read for StreamWrapper -where - S: AsyncRead + Unpin, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let (stream, cx) = unsafe { self.parts() }; - let mut buf = ReadBuf::new(buf); - match stream.poll_read(cx, &mut buf)? { - Poll::Ready(()) => Ok(buf.filled().len()), - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } -} - -impl Write for StreamWrapper -where - S: AsyncWrite + Unpin, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self.with_context(|ctx, stream| stream.poll_flush(ctx)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } + HandshakeFuture(Some(mid_handshake)).await } fn cvt(r: io::Result) -> Poll> { @@ -154,7 +82,7 @@ fn cvt(r: io::Result) -> Poll> { /// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written /// to a `SslStream` are encrypted when passing through to `S`. #[derive(Debug)] -pub struct SslStream(ssl::SslStream>); +pub struct SslStream(ssl::SslStream>); impl SslStream { /// Returns a shared reference to the `Ssl` object associated with this stream. @@ -172,14 +100,21 @@ impl SslStream { &mut self.0.get_mut().stream } - fn with_context(&mut self, ctx: &mut Context<'_>, f: F) -> R + fn run_in_context(&mut self, ctx: &mut Context<'_>, f: F) -> R where - F: FnOnce(&mut ssl::SslStream>) -> R, + F: FnOnce(&mut ssl::SslStream>) -> R, { - self.0.get_mut().context = ctx as *mut _ as usize; - let r = f(&mut self.0); - self.0.get_mut().context = 0; - r + self.0.get_mut().set_waker(Some(ctx)); + + let result = f(&mut self.0); + + // NOTE(nox): This should also be executed when `f` panics, + // but it's not that important as boring segfaults on panics + // and we always set the context prior to doing anything with + // the inner async stream. + self.0.get_mut().set_waker(None); + + result } } @@ -195,8 +130,10 @@ where /// /// The caller must ensure the pointer is valid. pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self { - let stream = StreamWrapper { stream, context: 0 }; - SslStream(ssl::SslStream::from_raw_parts(ssl, stream)) + Self(ssl::SslStream::from_raw_parts( + ssl, + AsyncStreamBridge::new(stream), + )) } } @@ -209,7 +146,7 @@ where ctx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll> { - self.with_context(ctx, |s| { + self.run_in_context(ctx, |s| { // This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though // OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now. let slice = unsafe { @@ -239,15 +176,15 @@ where ctx: &mut Context, buf: &[u8], ) -> Poll> { - self.with_context(ctx, |s| cvt(s.write(buf))) + self.run_in_context(ctx, |s| cvt(s.write(buf))) } fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { - self.with_context(ctx, |s| cvt(s.flush())) + self.run_in_context(ctx, |s| cvt(s.flush())) } fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { - match self.with_context(ctx, |s| s.shutdown()) { + match self.run_in_context(ctx, |s| s.shutdown()) { Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {} Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {} Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => { @@ -265,7 +202,7 @@ where } /// The error type returned after a failed handshake. -pub struct HandshakeError(ssl::HandshakeError>); +pub struct HandshakeError(ssl::HandshakeError>); impl HandshakeError { /// Returns a shared reference to the `Ssl` object associated with this error. @@ -333,7 +270,10 @@ where } } -struct HandshakeFuture(Option>>); +/// Future for an ongoing TLS handshake. +/// +/// See [`connect`] and [`accept`]. +pub struct HandshakeFuture(Option>>); impl Future for HandshakeFuture where @@ -341,25 +281,34 @@ where { type Output = Result, HandshakeError>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut s = self.0.take().expect("future polled after completion"); + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + let mut mid_handshake = self.0.take().expect("future polled after completion"); - s.get_mut().context = cx as *mut _ as usize; + mid_handshake.get_mut().set_waker(Some(ctx)); - match s.handshake() { - Ok(mut s) => { - s.get_mut().context = 0; + match mid_handshake.handshake() { + Ok(mut stream) => { + stream.get_mut().set_waker(None); - Poll::Ready(Ok(SslStream(s))) + Poll::Ready(Ok(SslStream(stream))) } - Err(ssl::HandshakeError::WouldBlock(mut s)) => { - s.get_mut().context = 0; + Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => { + mid_handshake.get_mut().set_waker(None); - self.0 = Some(s); + self.0 = Some(mid_handshake); Poll::Pending } - Err(e) => Poll::Ready(Err(HandshakeError(e))), + Err(ssl::HandshakeError::Failure(mut mid_handshake)) => { + mid_handshake.get_mut().set_waker(None); + + Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure( + mid_handshake, + )))) + } + Err(err @ ssl::HandshakeError::SetupFailure(_)) => { + Poll::Ready(Err(HandshakeError(err))) + } } } } From b15d305a9f80b543a0efec7ddbc1859992524d36 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Fri, 28 Jul 2023 11:30:19 +0200 Subject: [PATCH 07/12] Change signature for set_select_certificate_callback To handle lifetimes better and allow returning a &mut SslRef from the client hello struct passed to the closure from SslContextBuilder::set_select_certificate_callback, we make the ClientHello struct itself own a reference to the FFI client hello struct. --- boring/src/ssl/callbacks.rs | 9 ++++----- boring/src/ssl/mod.rs | 15 ++++++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/boring/src/ssl/callbacks.rs b/boring/src/ssl/callbacks.rs index e87b151f3..1e7a64da6 100644 --- a/boring/src/ssl/callbacks.rs +++ b/boring/src/ssl/callbacks.rs @@ -223,14 +223,13 @@ pub(super) unsafe extern "C" fn raw_select_cert( client_hello: *const ffi::SSL_CLIENT_HELLO, ) -> ffi::ssl_select_cert_result_t where - F: Fn(&ClientHello) -> Result<(), SelectCertError> + Sync + Send + 'static, + F: Fn(ClientHello<'_>) -> Result<(), SelectCertError> + Sync + Send + 'static, { // SAFETY: boring provides valid inputs. - let client_hello = unsafe { &*(client_hello as *const ClientHello) }; + let client_hello = ClientHello(unsafe { &*client_hello }); - let callback = client_hello - .ssl() - .ssl_context() + let ssl_context = client_hello.ssl().ssl_context().to_owned(); + let callback = ssl_context .ex_data(SslContext::cached_ex_index::()) .expect("BUG: select cert callback missing"); diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index c9cd6f6ef..34925fb82 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -1370,6 +1370,7 @@ impl SslContextBuilder { ); } } + /// Sets a callback that is called before most ClientHello processing and before the decision whether /// to resume a session is made. The callback may inspect the ClientHello and configure the /// connection. @@ -1379,7 +1380,7 @@ impl SslContextBuilder { /// [`SSL_CTX_set_select_certificate_cb`]: https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set_select_certificate_cb.html pub fn set_select_certificate_callback(&mut self, callback: F) where - F: Fn(&ClientHello) -> Result<(), SelectCertError> + Sync + Send + 'static, + F: Fn(ClientHello<'_>) -> Result<(), SelectCertError> + Sync + Send + 'static, { unsafe { self.set_ex_data(SslContext::cached_ex_index::(), callback); @@ -1959,9 +1960,9 @@ pub struct CipherBits { } #[repr(transparent)] -pub struct ClientHello(ffi::SSL_CLIENT_HELLO); +pub struct ClientHello<'ssl>(&'ssl ffi::SSL_CLIENT_HELLO); -impl ClientHello { +impl ClientHello<'_> { /// Returns the data of a given extension, if present. /// /// This corresponds to [`SSL_early_callback_ctx_extension_get`]. @@ -1972,7 +1973,7 @@ impl ClientHello { let mut ptr = ptr::null(); let mut len = 0; let result = - ffi::SSL_early_callback_ctx_extension_get(&self.0, ext_type.0, &mut ptr, &mut len); + ffi::SSL_early_callback_ctx_extension_get(self.0, ext_type.0, &mut ptr, &mut len); if result == 0 { return None; } @@ -1980,7 +1981,11 @@ impl ClientHello { } } - fn ssl(&self) -> &SslRef { + pub fn ssl_mut(&mut self) -> &mut SslRef { + unsafe { SslRef::from_ptr_mut(self.0.ssl) } + } + + pub fn ssl(&self) -> &SslRef { unsafe { SslRef::from_ptr(self.0.ssl) } } From 7038f177bacce92291c754b8fca6053d385c8d71 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Wed, 2 Aug 2023 10:36:26 +0200 Subject: [PATCH 08/12] Implement SslContextBuilder::set_private_key_method --- boring/src/ssl/callbacks.rs | 111 ++++++++- boring/src/ssl/mod.rs | 96 ++++++++ boring/src/ssl/test/mod.rs | 24 +- boring/src/ssl/test/private_key_method.rs | 282 ++++++++++++++++++++++ boring/src/ssl/test/server.rs | 33 ++- 5 files changed, 514 insertions(+), 32 deletions(-) create mode 100644 boring/src/ssl/test/private_key_method.rs diff --git a/boring/src/ssl/callbacks.rs b/boring/src/ssl/callbacks.rs index 1e7a64da6..dc9f2d53b 100644 --- a/boring/src/ssl/callbacks.rs +++ b/boring/src/ssl/callbacks.rs @@ -1,6 +1,13 @@ #![forbid(unsafe_op_in_unsafe_fn)] +use super::{ + AlpnError, ClientHello, PrivateKeyMethod, PrivateKeyMethodError, SelectCertError, SniError, + Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef, + SslSignatureAlgorithm, SESSION_CTX_INDEX, +}; +use crate::error::ErrorStack; use crate::ffi; +use crate::x509::{X509StoreContext, X509StoreContextRef}; use foreign_types::ForeignType; use foreign_types::ForeignTypeRef; use libc::c_char; @@ -12,19 +19,7 @@ use std::slice; use std::str; use std::sync::Arc; -use crate::error::ErrorStack; -use crate::ssl::AlpnError; -use crate::ssl::{ClientHello, SelectCertError}; -use crate::ssl::{ - SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef, - SESSION_CTX_INDEX, -}; -use crate::x509::{X509StoreContext, X509StoreContextRef}; - -pub(super) unsafe extern "C" fn raw_verify( - preverify_ok: c_int, - x509_ctx: *mut ffi::X509_STORE_CTX, -) -> c_int +pub extern "C" fn raw_verify(preverify_ok: c_int, x509_ctx: *mut ffi::X509_STORE_CTX) -> c_int where F: Fn(bool, &mut X509StoreContextRef) -> bool + 'static + Sync + Send, { @@ -372,3 +367,93 @@ where callback(ssl, line); } + +pub(super) unsafe extern "C" fn raw_sign( + ssl: *mut ffi::SSL, + out: *mut u8, + out_len: *mut usize, + max_out: usize, + signature_algorithm: u16, + in_: *const u8, + in_len: usize, +) -> ffi::ssl_private_key_result_t +where + M: PrivateKeyMethod, +{ + // SAFETY: boring provides valid inputs. + let input = unsafe { slice::from_raw_parts(in_, in_len) }; + + let signature_algorithm = SslSignatureAlgorithm(signature_algorithm); + + let callback = |method: &M, ssl: &mut _, output: &mut _| { + method.sign(ssl, input, signature_algorithm, output) + }; + + // SAFETY: boring provides valid inputs. + unsafe { raw_private_key_callback(ssl, out, out_len, max_out, callback) } +} + +pub(super) unsafe extern "C" fn raw_decrypt( + ssl: *mut ffi::SSL, + out: *mut u8, + out_len: *mut usize, + max_out: usize, + in_: *const u8, + in_len: usize, +) -> ffi::ssl_private_key_result_t +where + M: PrivateKeyMethod, +{ + // SAFETY: boring provides valid inputs. + let input = unsafe { slice::from_raw_parts(in_, in_len) }; + + let callback = |method: &M, ssl: &mut _, output: &mut _| method.decrypt(ssl, input, output); + + // SAFETY: boring provides valid inputs. + unsafe { raw_private_key_callback(ssl, out, out_len, max_out, callback) } +} + +pub(super) unsafe extern "C" fn raw_complete( + ssl: *mut ffi::SSL, + out: *mut u8, + out_len: *mut usize, + max_out: usize, +) -> ffi::ssl_private_key_result_t +where + M: PrivateKeyMethod, +{ + // SAFETY: boring provides valid inputs. + unsafe { raw_private_key_callback::(ssl, out, out_len, max_out, M::complete) } +} + +unsafe fn raw_private_key_callback( + ssl: *mut ffi::SSL, + out: *mut u8, + out_len: *mut usize, + max_out: usize, + callback: impl FnOnce(&M, &mut SslRef, &mut [u8]) -> Result, +) -> ffi::ssl_private_key_result_t +where + M: PrivateKeyMethod, +{ + // SAFETY: boring provides valid inputs. + let ssl = unsafe { SslRef::from_ptr_mut(ssl) }; + let output = unsafe { slice::from_raw_parts_mut(out, max_out) }; + let out_len = unsafe { &mut *out_len }; + + let ssl_context = ssl.ssl_context().to_owned(); + let method = ssl_context + .ex_data(SslContext::cached_ex_index::()) + .expect("BUG: private key method missing"); + + match callback(method, ssl, output) { + Ok(written) => { + assert!(written <= max_out); + + *out_len = written; + + ffi::ssl_private_key_result_t::ssl_private_key_success + } + Err(err) => err.0, + } +} diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 34925fb82..c0a6c3413 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -1391,6 +1391,31 @@ impl SslContextBuilder { } } + /// Configures a custom private key method on the context. + /// + /// See [`PrivateKeyMethod`] for more details. + /// + /// This corresponds to [`SSL_CTX_set_private_key_method`] + /// + /// [`SSL_CTX_set_private_key_method`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_CTX_set_private_key_method + pub fn set_private_key_method(&mut self, method: M) + where + M: PrivateKeyMethod, + { + unsafe { + self.set_ex_data(SslContext::cached_ex_index::(), method); + + ffi::SSL_CTX_set_private_key_method( + self.as_ptr(), + &ffi::SSL_PRIVATE_KEY_METHOD { + sign: Some(callbacks::raw_sign::), + decrypt: Some(callbacks::raw_decrypt::), + complete: Some(callbacks::raw_complete::), + }, + ) + } + } + /// Checks for consistency between the private key and certificate. /// /// This corresponds to [`SSL_CTX_check_private_key`]. @@ -3790,6 +3815,77 @@ bitflags! { } } +/// Describes private key hooks. This is used to off-load signing operations to +/// a custom, potentially asynchronous, backend. Metadata about the key such as +/// the type and size are parsed out of the certificate. +/// +/// Corresponds to [`ssl_private_key_method_st`]. +/// +/// [`ssl_private_key_method_st`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_private_key_method_st +pub trait PrivateKeyMethod: Send + Sync + 'static { + /// Signs the message `input` using the specified signature algorithm. + /// + /// On success, it returns `Ok(written)` where `written` is the number of + /// bytes written into `output`. On failure, it returns + /// `Err(PrivateKeyMethodError::FAILURE)`. If the operation has not completed, + /// it returns `Err(PrivateKeyMethodError::RETRY)`. + /// + /// The caller should arrange for the high-level operation on `ssl` to be + /// retried when the operation is completed. This will result in a call to + /// [`Self::complete`]. + fn sign( + &self, + ssl: &mut SslRef, + input: &[u8], + signature_algorithm: SslSignatureAlgorithm, + output: &mut [u8], + ) -> Result; + + /// Decrypts `input`. + /// + /// On success, it returns `Ok(written)` where `written` is the number of + /// bytes written into `output`. On failure, it returns + /// `Err(PrivateKeyMethodError::FAILURE)`. If the operation has not completed, + /// it returns `Err(PrivateKeyMethodError::RETRY)`. + /// + /// The caller should arrange for the high-level operation on `ssl` to be + /// retried when the operation is completed. This will result in a call to + /// [`Self::complete`]. + /// + /// This method only works with RSA keys and should perform a raw RSA + /// decryption operation with no padding. + // NOTE(nox): What does it mean that it is an error? + fn decrypt( + &self, + ssl: &mut SslRef, + input: &[u8], + output: &mut [u8], + ) -> Result; + + /// Completes a pending operation. + /// + /// On success, it returns `Ok(written)` where `written` is the number of + /// bytes written into `output`. On failure, it returns + /// `Err(PrivateKeyMethodError::FAILURE)`. If the operation has not completed, + /// it returns `Err(PrivateKeyMethodError::RETRY)`. + /// + /// This method may be called arbitrarily many times before completion. + fn complete(&self, ssl: &mut SslRef, output: &mut [u8]) + -> Result; +} + +/// An error returned from a private key method. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct PrivateKeyMethodError(ffi::ssl_private_key_result_t); + +impl PrivateKeyMethodError { + /// A fatal error occured and the handshake should be terminated. + pub const FAILURE: Self = Self(ffi::ssl_private_key_result_t::ssl_private_key_failure); + + /// The operation could not be completed and should be retried later. + pub const RETRY: Self = Self(ffi::ssl_private_key_result_t::ssl_private_key_retry); +} + use crate::ffi::{SSL_CTX_up_ref, SSL_SESSION_get_master_key, SSL_SESSION_up_ref, SSL_is_server}; use crate::ffi::{DTLS_method, TLS_client_method, TLS_method, TLS_server_method}; diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index 5c986199a..a68c3dc70 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -34,6 +34,7 @@ use crate::x509::store::X509StoreBuilder; use crate::x509::verify::X509CheckFlags; use crate::x509::{X509Name, X509StoreContext, X509VerifyResult, X509}; +mod private_key_method; mod server; static ROOT_CERT: &[u8] = include_bytes!("../../../test/root-ca.pem"); @@ -55,9 +56,7 @@ fn verify_untrusted() { #[test] fn verify_trusted() { let server = Server::builder().build(); - - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); + let client = server.client_with_root_ca(); client.connect(); } @@ -109,9 +108,8 @@ fn verify_untrusted_callback_override_bad() { #[test] fn verify_trusted_callback_override_ok() { let server = Server::builder().build(); + let mut client = server.client_with_root_ca(); - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); client .ctx() .set_verify_callback(SslVerifyMode::PEER, |_, x509| { @@ -125,11 +123,12 @@ fn verify_trusted_callback_override_ok() { #[test] fn verify_trusted_callback_override_bad() { let mut server = Server::builder(); + server.should_error(); + let server = server.build(); + let mut client = server.client_with_root_ca(); - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); client .ctx() .set_verify_callback(SslVerifyMode::PEER, |_, _| false); @@ -155,9 +154,8 @@ fn verify_callback_load_certs() { #[test] fn verify_trusted_get_error_ok() { let server = Server::builder().build(); + let mut client = server.client_with_root_ca(); - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); client .ctx() .set_verify_callback(SslVerifyMode::PEER, |_, x509| { @@ -697,9 +695,8 @@ fn add_extra_chain_cert() { #[test] fn verify_valid_hostname() { let server = Server::builder().build(); + let mut client = server.client_with_root_ca(); - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); client.ctx().set_verify(SslVerifyMode::PEER); let mut client = client.build().builder(); @@ -714,11 +711,12 @@ fn verify_valid_hostname() { #[test] fn verify_invalid_hostname() { let mut server = Server::builder(); + server.should_error(); + let server = server.build(); + let mut client = server.client_with_root_ca(); - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); client.ctx().set_verify(SslVerifyMode::PEER); let mut client = client.build().builder(); diff --git a/boring/src/ssl/test/private_key_method.rs b/boring/src/ssl/test/private_key_method.rs new file mode 100644 index 000000000..f711fccc0 --- /dev/null +++ b/boring/src/ssl/test/private_key_method.rs @@ -0,0 +1,282 @@ +use once_cell::sync::OnceCell; + +use super::server::{Builder, Server}; +use super::KEY; +use crate::hash::{Hasher, MessageDigest}; +use crate::pkey::PKey; +use crate::rsa::Padding; +use crate::sign::{RsaPssSaltlen, Signer}; +use crate::ssl::{ + ErrorCode, HandshakeError, PrivateKeyMethod, PrivateKeyMethodError, SslRef, + SslSignatureAlgorithm, +}; +use crate::x509::X509; +use std::cmp; +use std::io::{Read, Write}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; + +#[allow(clippy::type_complexity)] +pub(super) struct Method { + sign: Box< + dyn Fn( + &mut SslRef, + &[u8], + SslSignatureAlgorithm, + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + >, + decrypt: Box< + dyn Fn(&mut SslRef, &[u8], &mut [u8]) -> Result + + Send + + Sync + + 'static, + >, + complete: Box< + dyn Fn(&mut SslRef, &mut [u8]) -> Result + + Send + + Sync + + 'static, + >, +} + +impl Method { + pub(super) fn new() -> Self { + Self { + sign: Box::new(|_, _, _, _| unreachable!("called sign")), + decrypt: Box::new(|_, _, _| unreachable!("called decrypt")), + complete: Box::new(|_, _| unreachable!("called complete")), + } + } + + pub(super) fn sign( + mut self, + sign: impl Fn( + &mut SslRef, + &[u8], + SslSignatureAlgorithm, + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + ) -> Self { + self.sign = Box::new(sign); + + self + } + + #[allow(dead_code)] + pub(super) fn decrypt( + mut self, + decrypt: impl Fn(&mut SslRef, &[u8], &mut [u8]) -> Result + + Send + + Sync + + 'static, + ) -> Self { + self.decrypt = Box::new(decrypt); + + self + } + + pub(super) fn complete( + mut self, + complete: impl Fn(&mut SslRef, &mut [u8]) -> Result + + Send + + Sync + + 'static, + ) -> Self { + self.complete = Box::new(complete); + + self + } +} + +impl PrivateKeyMethod for Method { + fn sign( + &self, + ssl: &mut SslRef, + input: &[u8], + signature_algorithm: SslSignatureAlgorithm, + output: &mut [u8], + ) -> Result { + (self.sign)(ssl, input, signature_algorithm, output) + } + + fn decrypt( + &self, + ssl: &mut SslRef, + input: &[u8], + output: &mut [u8], + ) -> Result { + (self.decrypt)(ssl, input, output) + } + + fn complete( + &self, + ssl: &mut SslRef, + output: &mut [u8], + ) -> Result { + (self.complete)(ssl, output) + } +} + +fn builder_with_private_key_method(method: Method) -> Builder { + let mut builder = Server::builder(); + + builder.ctx().set_private_key_method(method); + + builder +} + +#[test] +fn test_sign_failure() { + let called_sign = Arc::new(AtomicBool::new(false)); + let called_sign_clone = called_sign.clone(); + + let mut builder = builder_with_private_key_method(Method::new().sign(move |_, _, _, _| { + called_sign_clone.store(true, Ordering::SeqCst); + + Err(PrivateKeyMethodError::FAILURE) + })); + + builder.err_cb(|error| { + let HandshakeError::Failure(mid_handshake) = error else { + panic!("should be Failure"); + }; + + assert_eq!(mid_handshake.error().code(), ErrorCode::SSL); + }); + + let server = builder.build(); + let client = server.client_with_root_ca(); + + client.connect_err(); + + assert!(called_sign.load(Ordering::SeqCst)); +} + +#[test] +fn test_sign_retry_complete_failure() { + let called_complete = Arc::new(AtomicUsize::new(0)); + let called_complete_clone = called_complete.clone(); + + let mut builder = builder_with_private_key_method( + Method::new() + .sign(|_, _, _, _| Err(PrivateKeyMethodError::RETRY)) + .complete(move |_, _| { + let old = called_complete_clone.fetch_add(1, Ordering::SeqCst); + + Err(if old == 0 { + PrivateKeyMethodError::RETRY + } else { + PrivateKeyMethodError::FAILURE + }) + }), + ); + + builder.err_cb(|error| { + let HandshakeError::WouldBlock(mid_handshake) = error else { + panic!("should be WouldBlock"); + }; + + assert!(mid_handshake.error().would_block()); + assert_eq!( + mid_handshake.error().code(), + ErrorCode::WANT_PRIVATE_KEY_OPERATION + ); + + let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err() else { + panic!("should be WouldBlock"); + }; + + assert_eq!( + mid_handshake.error().code(), + ErrorCode::WANT_PRIVATE_KEY_OPERATION + ); + + let HandshakeError::Failure(mid_handshake) = mid_handshake.handshake().unwrap_err() else { + panic!("should be Failure"); + }; + + assert_eq!(mid_handshake.error().code(), ErrorCode::SSL); + }); + + let server = builder.build(); + let client = server.client_with_root_ca(); + + client.connect_err(); + + assert_eq!(called_complete.load(Ordering::SeqCst), 2); +} + +#[test] +fn test_sign_ok() { + let server = builder_with_private_key_method(Method::new().sign( + |_, input, signature_algorithm, output| { + assert_eq!( + signature_algorithm, + SslSignatureAlgorithm::RSA_PSS_RSAE_SHA256, + ); + + Ok(sign_with_default_config(input, output)) + }, + )) + .build(); + + let client = server.client_with_root_ca(); + + client.connect(); +} + +#[test] +fn test_sign_retry_complete_ok() { + let input_cell = Arc::new(OnceCell::new()); + let input_cell_clone = input_cell.clone(); + + let mut builder = builder_with_private_key_method( + Method::new() + .sign(move |_, input, _, _| { + input_cell.set(input.to_owned()).unwrap(); + + Err(PrivateKeyMethodError::RETRY) + }) + .complete(move |_, output| { + let input = input_cell_clone.get().unwrap(); + + Ok(sign_with_default_config(input, output)) + }), + ); + + builder.err_cb(|error| { + let HandshakeError::WouldBlock(mid_handshake) = error else { + panic!("should be WouldBlock"); + }; + + let mut socket = mid_handshake.handshake().unwrap(); + + socket.write_all(&[0]).unwrap(); + }); + + let server = builder.build(); + let client = server.client_with_root_ca(); + + client.connect(); +} + +fn sign_with_default_config(input: &[u8], output: &mut [u8]) -> usize { + let pkey = PKey::private_key_from_pem(KEY).unwrap(); + let mut signer = Signer::new(MessageDigest::sha256(), &pkey).unwrap(); + + signer.set_rsa_padding(Padding::PKCS1_PSS).unwrap(); + signer + .set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH) + .unwrap(); + + signer.update(input).unwrap(); + + signer.sign(output).unwrap() +} diff --git a/boring/src/ssl/test/server.rs b/boring/src/ssl/test/server.rs index 41677e576..7d79cd754 100644 --- a/boring/src/ssl/test/server.rs +++ b/boring/src/ssl/test/server.rs @@ -2,7 +2,10 @@ use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::thread::{self, JoinHandle}; -use crate::ssl::{Ssl, SslContext, SslContextBuilder, SslFiletype, SslMethod, SslRef, SslStream}; +use crate::ssl::{ + HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype, + SslMethod, SslRef, SslStream, +}; pub struct Server { handle: Option>, @@ -28,6 +31,7 @@ impl Server { ctx, ssl_cb: Box::new(|_| {}), io_cb: Box::new(|_| {}), + err_cb: Box::new(|_| {}), should_error: false, } } @@ -39,6 +43,14 @@ impl Server { } } + pub fn client_with_root_ca(&self) -> ClientBuilder { + let mut client = self.client(); + + client.ctx().set_ca_file("test/root-ca.pem").unwrap(); + + client + } + pub fn connect_tcp(&self) -> TcpStream { TcpStream::connect(self.addr).unwrap() } @@ -48,6 +60,7 @@ pub struct Builder { ctx: SslContextBuilder, ssl_cb: Box, io_cb: Box) + Send>, + err_cb: Box) + Send>, should_error: bool, } @@ -70,6 +83,12 @@ impl Builder { self.io_cb = Box::new(cb); } + pub fn err_cb(&mut self, cb: impl FnMut(HandshakeError) + Send + 'static) { + self.should_error(); + + self.err_cb = Box::new(cb); + } + pub fn should_error(&mut self) { self.should_error = true; } @@ -80,6 +99,7 @@ impl Builder { let addr = socket.local_addr().unwrap(); let mut ssl_cb = self.ssl_cb; let mut io_cb = self.io_cb; + let mut err_cb = self.err_cb; let should_error = self.should_error; let handle = thread::spawn(move || { @@ -88,7 +108,7 @@ impl Builder { ssl_cb(&mut ssl); let r = ssl.accept(socket); if should_error { - r.unwrap_err(); + err_cb(r.unwrap_err()); } else { let mut socket = r.unwrap(); socket.write_all(&[0]).unwrap(); @@ -124,8 +144,8 @@ impl ClientBuilder { self.build().builder().connect() } - pub fn connect_err(self) { - self.build().builder().connect_err(); + pub fn connect_err(self) -> HandshakeError { + self.build().builder().connect_err() } } @@ -160,8 +180,9 @@ impl ClientSslBuilder { s } - pub fn connect_err(self) { + pub fn connect_err(self) -> HandshakeError { let socket = TcpStream::connect(self.addr).unwrap(); - self.ssl.connect(socket).unwrap_err(); + + self.ssl.setup_connect(socket).handshake().unwrap_err() } } From c445931c97f02459b5785098099d76fa10eacba6 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Fri, 28 Jul 2023 15:15:11 +0200 Subject: [PATCH 09/12] Introduce async callbacks We introduce tokio_boring::SslContextBuilderExt, with 2 methods: * set_async_select_certificate_callback * set_async_private_key_method --- boring/src/ssl/mod.rs | 13 + boring/src/ssl/test/private_key_method.rs | 3 +- tokio-boring/Cargo.toml | 1 + tokio-boring/src/async_callbacks.rs | 263 ++++++++++++++++++ tokio-boring/src/bridge.rs | 3 +- tokio-boring/src/lib.rs | 17 ++ .../tests/async_private_key_method.rs | 187 +++++++++++++ .../tests/async_select_certificate.rs | 96 +++++++ 8 files changed, 580 insertions(+), 3 deletions(-) create mode 100644 tokio-boring/src/async_callbacks.rs create mode 100644 tokio-boring/tests/async_private_key_method.rs create mode 100644 tokio-boring/tests/async_select_certificate.rs diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index c0a6c3413..d4f763ff7 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -482,6 +482,9 @@ pub struct SelectCertError(ffi::ssl_select_cert_result_t); impl SelectCertError { /// A fatal error occured and the handshake should be terminated. pub const ERROR: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_error); + + /// The operation could not be completed and should be retried later. + pub const RETRY: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_retry); } /// Extension types, to be used with `ClientHello::get_extension`. @@ -3280,6 +3283,11 @@ impl MidHandshakeSslStream { self.stream.ssl() } + /// Returns a mutable reference to the `Ssl` of the stream. + pub fn ssl_mut(&mut self) -> &mut SslRef { + self.stream.ssl_mut() + } + /// Returns the underlying error which interrupted this handshake. pub fn error(&self) -> &Error { &self.error @@ -3585,6 +3593,11 @@ impl SslStream { pub fn ssl(&self) -> &SslRef { &self.ssl } + + /// Returns a mutable reference to the `Ssl` object associated with this stream. + pub fn ssl_mut(&mut self) -> &mut SslRef { + &mut self.ssl + } } impl Read for SslStream { diff --git a/boring/src/ssl/test/private_key_method.rs b/boring/src/ssl/test/private_key_method.rs index f711fccc0..722ed8f11 100644 --- a/boring/src/ssl/test/private_key_method.rs +++ b/boring/src/ssl/test/private_key_method.rs @@ -189,7 +189,8 @@ fn test_sign_retry_complete_failure() { ErrorCode::WANT_PRIVATE_KEY_OPERATION ); - let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err() else { + let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err() + else { panic!("should be WouldBlock"); }; diff --git a/tokio-boring/Cargo.toml b/tokio-boring/Cargo.toml index 73d5c8ec2..509ab6f73 100644 --- a/tokio-boring/Cargo.toml +++ b/tokio-boring/Cargo.toml @@ -39,6 +39,7 @@ no-patches = ["boring/no-patches"] [dependencies] boring = { workspace = true } boring-sys = { workspace = true } +once_cell = { workspace = true } tokio = { workspace = true } [dev-dependencies] diff --git a/tokio-boring/src/async_callbacks.rs b/tokio-boring/src/async_callbacks.rs new file mode 100644 index 000000000..ee658ea72 --- /dev/null +++ b/tokio-boring/src/async_callbacks.rs @@ -0,0 +1,263 @@ +use boring::ex_data::Index; +use boring::ssl::{self, ClientHello, PrivateKeyMethod, Ssl, SslContextBuilder}; +use once_cell::sync::Lazy; +use std::future::Future; +use std::pin::Pin; +use std::task::{ready, Context, Poll, Waker}; + +/// The type of futures to pass to [`SslContextBuilderExt::set_async_select_certificate_callback`]. +pub type BoxSelectCertFuture = ExDataFuture>; + +/// The type of callbacks returned by [`BoxSelectCertFuture`] methods. +pub type BoxSelectCertFinish = Box) -> Result<(), AsyncSelectCertError>>; + +/// The type of futures returned by [`AsyncPrivateKeyMethod`] methods. +pub type BoxPrivateKeyMethodFuture = + ExDataFuture>; + +/// The type of callbacks returned by [`BoxPrivateKeyMethodFuture`]. +pub type BoxPrivateKeyMethodFinish = + Box Result>; + +/// Convenience alias for futures stored in [`Ssl`] ex data by [`SslContextBuilderExt`] methods. +/// +/// Public for documentation purposes. +pub type ExDataFuture = Pin + Send + Sync>>; + +pub(crate) static TASK_WAKER_INDEX: Lazy>> = + Lazy::new(|| Ssl::new_ex_index().unwrap()); +pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy> = + Lazy::new(|| Ssl::new_ex_index().unwrap()); +pub(crate) static SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX: Lazy< + Index, +> = Lazy::new(|| Ssl::new_ex_index().unwrap()); + +/// Extensions to [`SslContextBuilder`]. +/// +/// This trait provides additional methods to use async callbacks with boring. +pub trait SslContextBuilderExt: private::Sealed { + /// Sets a callback that is called before most [`ClientHello`] processing + /// and before the decision whether to resume a session is made. The + /// callback may inspect the [`ClientHello`] and configure the connection. + /// + /// This method uses a function that returns a future whose output is + /// itself a closure that will be passed [`ClientHello`] to configure + /// the connection based on the computations done in the future. + /// + /// See [`SslContextBuilder::set_select_certificate_callback`] for the sync + /// setter of this callback. + fn set_async_select_certificate_callback(&mut self, callback: F) + where + F: Fn(&mut ClientHello<'_>) -> Result + + Send + + Sync + + 'static; + + /// Configures a custom private key method on the context. + /// + /// See [`AsyncPrivateKeyMethod`] for more details. + fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod); +} + +impl SslContextBuilderExt for SslContextBuilder { + fn set_async_select_certificate_callback(&mut self, callback: F) + where + F: Fn(&mut ClientHello<'_>) -> Result + + Send + + Sync + + 'static, + { + self.set_select_certificate_callback(move |mut client_hello| { + let fut_poll_result = with_ex_data_future( + &mut client_hello, + *SELECT_CERT_FUTURE_INDEX, + ClientHello::ssl_mut, + &callback, + ); + + let fut_result = match fut_poll_result { + Poll::Ready(fut_result) => fut_result, + Poll::Pending => return Err(ssl::SelectCertError::RETRY), + }; + + let finish = fut_result.or(Err(ssl::SelectCertError::ERROR))?; + + finish(client_hello).or(Err(ssl::SelectCertError::ERROR)) + }) + } + + fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) { + self.set_private_key_method(AsyncPrivateKeyMethodBridge(Box::new(method))); + } +} + +/// A fatal error to be returned from async select certificate callbacks. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct AsyncSelectCertError; + +/// Describes async private key hooks. This is used to off-load signing +/// operations to a custom, potentially asynchronous, backend. Metadata about the +/// key such as the type and size are parsed out of the certificate. +/// +/// See [`PrivateKeyMethod`] for the sync version of those hooks. +/// +/// [`ssl_private_key_method_st`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_private_key_method_st +pub trait AsyncPrivateKeyMethod: Send + Sync + 'static { + /// Signs the message `input` using the specified signature algorithm. + /// + /// This method uses a function that returns a future whose output is + /// itself a closure that will be passed `ssl` and `output` + /// to finish writing the signature. + /// + /// See [`PrivateKeyMethod::sign`] for the sync version of this method. + fn sign( + &self, + ssl: &mut ssl::SslRef, + input: &[u8], + signature_algorithm: ssl::SslSignatureAlgorithm, + output: &mut [u8], + ) -> Result; + + /// Decrypts `input`. + /// + /// This method uses a function that returns a future whose output is + /// itself a closure that will be passed `ssl` and `output` + /// to finish decrypting the input. + /// + /// See [`PrivateKeyMethod::decrypt`] for the sync version of this method. + fn decrypt( + &self, + ssl: &mut ssl::SslRef, + input: &[u8], + output: &mut [u8], + ) -> Result; +} + +/// A fatal error to be returned from async private key methods. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct AsyncPrivateKeyMethodError; + +struct AsyncPrivateKeyMethodBridge(Box); + +impl PrivateKeyMethod for AsyncPrivateKeyMethodBridge { + fn sign( + &self, + ssl: &mut ssl::SslRef, + input: &[u8], + signature_algorithm: ssl::SslSignatureAlgorithm, + output: &mut [u8], + ) -> Result { + with_private_key_method(ssl, output, |ssl, output| { + ::sign(&*self.0, ssl, input, signature_algorithm, output) + }) + } + + fn decrypt( + &self, + ssl: &mut ssl::SslRef, + input: &[u8], + output: &mut [u8], + ) -> Result { + with_private_key_method(ssl, output, |ssl, output| { + ::decrypt(&*self.0, ssl, input, output) + }) + } + + fn complete( + &self, + ssl: &mut ssl::SslRef, + output: &mut [u8], + ) -> Result { + with_private_key_method(ssl, output, |_, _| { + // This should never be reached, if it does, that's a bug on boring's side, + // which called `complete` without having been returned to with a pending + // future from `sign` or `decrypt`. + + if cfg!(debug_assertions) { + panic!("BUG: boring called complete without a pending operation"); + } + + Err(AsyncPrivateKeyMethodError) + }) + } +} + +/// Creates and drives a private key method future. +/// +/// This is a convenience function for the three methods of impl `PrivateKeyMethod`` +/// for `dyn AsyncPrivateKeyMethod`. It relies on [`with_ex_data_future`] to +/// drive the future and then immediately calls the final [`BoxPrivateKeyMethodFinish`] +/// when the future is ready. +fn with_private_key_method( + ssl: &mut ssl::SslRef, + output: &mut [u8], + create_fut: impl FnOnce( + &mut ssl::SslRef, + &mut [u8], + ) -> Result, +) -> Result { + let fut_poll_result = with_ex_data_future( + ssl, + *SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX, + |ssl| ssl, + |ssl| create_fut(ssl, output), + ); + + let fut_result = match fut_poll_result { + Poll::Ready(fut_result) => fut_result, + Poll::Pending => return Err(ssl::PrivateKeyMethodError::RETRY), + }; + + let finish = fut_result.or(Err(ssl::PrivateKeyMethodError::FAILURE))?; + + finish(ssl, output).or(Err(ssl::PrivateKeyMethodError::FAILURE)) +} + +/// Creates and drives a future stored in `ssl_handle`'s `Ssl` at ex data index `index`. +/// +/// This function won't even bother storing the future in `index` if the future +/// created by `create_fut` returns `Poll::Ready(_)` on the first poll call. +fn with_ex_data_future( + ssl_handle: &mut H, + index: Index>>, + get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef, + create_fut: impl FnOnce(&mut H) -> Result>, E>, +) -> Poll> { + let ssl = get_ssl_mut(ssl_handle); + let waker = ssl + .ex_data(*TASK_WAKER_INDEX) + .cloned() + .flatten() + .expect("task waker should be set"); + + let mut ctx = Context::from_waker(&waker); + + match ssl.ex_data_mut(index) { + Some(fut) => { + let fut_result = ready!(fut.as_mut().poll(&mut ctx)); + + // NOTE(nox): For memory usage concerns, maybe we should implement + // a way to remove the stored future from the `Ssl` value here? + + Poll::Ready(fut_result) + } + None => { + let mut fut = create_fut(ssl_handle)?; + + match fut.as_mut().poll(&mut ctx) { + Poll::Ready(fut_result) => Poll::Ready(fut_result), + Poll::Pending => { + get_ssl_mut(ssl_handle).set_ex_data(index, fut); + + Poll::Pending + } + } + } + } +} + +mod private { + pub trait Sealed {} +} + +impl private::Sealed for SslContextBuilder {} diff --git a/tokio-boring/src/bridge.rs b/tokio-boring/src/bridge.rs index 9de3fc224..62ed7729f 100644 --- a/tokio-boring/src/bridge.rs +++ b/tokio-boring/src/bridge.rs @@ -1,5 +1,4 @@ //! Bridge between sync IO traits and async tokio IO traits. - use std::fmt; use std::io; use std::pin::Pin; @@ -35,7 +34,7 @@ impl AsyncStreamBridge { F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, { let mut ctx = - Context::from_waker(self.waker.as_ref().expect("missing task context pointer")); + Context::from_waker(self.waker.as_ref().expect("BUG: missing waker in bridge")); f(&mut ctx, Pin::new(&mut self.stream)) } diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index f437ee262..a8ab50ad6 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -27,8 +27,15 @@ use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +mod async_callbacks; mod bridge; +use self::async_callbacks::TASK_WAKER_INDEX; +pub use self::async_callbacks::{ + AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, + BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish, BoxSelectCertFuture, + ExDataFuture, SslContextBuilderExt, +}; use self::bridge::AsyncStreamBridge; /// Asynchronously performs a client-side TLS handshake over the provided stream. @@ -90,6 +97,11 @@ impl SslStream { self.0.ssl() } + /// Returns a mutable reference to the `Ssl` object associated with this stream. + pub fn ssl_mut(&mut self) -> &mut SslRef { + self.0.ssl_mut() + } + /// Returns a shared reference to the underlying stream. pub fn get_ref(&self) -> &S { &self.0.get_ref().stream @@ -285,15 +297,20 @@ where let mut mid_handshake = self.0.take().expect("future polled after completion"); mid_handshake.get_mut().set_waker(Some(ctx)); + mid_handshake + .ssl_mut() + .set_ex_data(*TASK_WAKER_INDEX, Some(ctx.waker().clone())); match mid_handshake.handshake() { Ok(mut stream) => { stream.get_mut().set_waker(None); + stream.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None); Poll::Ready(Ok(SslStream(stream))) } Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => { mid_handshake.get_mut().set_waker(None); + mid_handshake.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None); self.0 = Some(mid_handshake); diff --git a/tokio-boring/tests/async_private_key_method.rs b/tokio-boring/tests/async_private_key_method.rs new file mode 100644 index 000000000..b39ef7d38 --- /dev/null +++ b/tokio-boring/tests/async_private_key_method.rs @@ -0,0 +1,187 @@ +use boring::hash::MessageDigest; +use boring::pkey::PKey; +use boring::rsa::Padding; +use boring::sign::{RsaPssSaltlen, Signer}; +use boring::ssl::{SslRef, SslSignatureAlgorithm}; +use futures::future; +use tokio::task::yield_now; +use tokio_boring::{ + AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, BoxPrivateKeyMethodFuture, + SslContextBuilderExt, +}; + +mod common; + +use self::common::{connect, create_server, with_trivial_client_server_exchange}; + +#[allow(clippy::type_complexity)] +struct Method { + sign: Box< + dyn Fn( + &mut SslRef, + &[u8], + SslSignatureAlgorithm, + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + >, + decrypt: Box< + dyn Fn( + &mut SslRef, + &[u8], + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + >, +} + +impl Method { + fn new() -> Self { + Self { + sign: Box::new(|_, _, _, _| unreachable!("called sign")), + decrypt: Box::new(|_, _, _| unreachable!("called decrypt")), + } + } + + fn sign( + mut self, + sign: impl Fn( + &mut SslRef, + &[u8], + SslSignatureAlgorithm, + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + ) -> Self { + self.sign = Box::new(sign); + + self + } + + #[allow(dead_code)] + fn decrypt( + mut self, + decrypt: impl Fn( + &mut SslRef, + &[u8], + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + ) -> Self { + self.decrypt = Box::new(decrypt); + + self + } +} + +impl AsyncPrivateKeyMethod for Method { + fn sign( + &self, + ssl: &mut SslRef, + input: &[u8], + signature_algorithm: SslSignatureAlgorithm, + output: &mut [u8], + ) -> Result { + (self.sign)(ssl, input, signature_algorithm, output) + } + + fn decrypt( + &self, + ssl: &mut SslRef, + input: &[u8], + output: &mut [u8], + ) -> Result { + (self.decrypt)(ssl, input, output) + } +} + +#[tokio::test] +async fn test_sign_failure() { + with_async_private_key_method_error( + Method::new().sign(|_, _, _, _| Err(AsyncPrivateKeyMethodError)), + ) + .await; +} + +#[tokio::test] +async fn test_sign_future_failure() { + with_async_private_key_method_error( + Method::new().sign(|_, _, _, _| Ok(Box::pin(async { Err(AsyncPrivateKeyMethodError) }))), + ) + .await; +} + +#[tokio::test] +async fn test_sign_future_yield_failure() { + with_async_private_key_method_error(Method::new().sign(|_, _, _, _| { + Ok(Box::pin(async { + yield_now().await; + + Err(AsyncPrivateKeyMethodError) + })) + })) + .await; +} + +#[tokio::test] +async fn test_sign_ok() { + with_trivial_client_server_exchange(|builder| { + builder.set_async_private_key_method(Method::new().sign( + |_, input, signature_algorithm, _| { + assert_eq!( + signature_algorithm, + SslSignatureAlgorithm::RSA_PSS_RSAE_SHA256, + ); + + let input = input.to_owned(); + + Ok(Box::pin(async move { + Ok(Box::new(move |_: &mut SslRef, output: &mut [u8]| { + Ok(sign_with_default_config(&input, output)) + }) as Box<_>) + })) + }, + )); + }) + .await; +} + +fn sign_with_default_config(input: &[u8], output: &mut [u8]) -> usize { + let pkey = PKey::private_key_from_pem(include_bytes!("key.pem")).unwrap(); + let mut signer = Signer::new(MessageDigest::sha256(), &pkey).unwrap(); + + signer.set_rsa_padding(Padding::PKCS1_PSS).unwrap(); + signer + .set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH) + .unwrap(); + + signer.update(input).unwrap(); + + signer.sign(output).unwrap() +} + +async fn with_async_private_key_method_error(method: Method) { + let (stream, addr) = create_server(move |builder| { + builder.set_async_private_key_method(method); + }); + + let server = async { + let _err = stream.await.unwrap_err(); + }; + + let client = async { + let _err = connect(addr, |builder| builder.set_ca_file("tests/cert.pem")) + .await + .unwrap_err(); + }; + + future::join(server, client).await; +} diff --git a/tokio-boring/tests/async_select_certificate.rs b/tokio-boring/tests/async_select_certificate.rs new file mode 100644 index 000000000..e77e1167a --- /dev/null +++ b/tokio-boring/tests/async_select_certificate.rs @@ -0,0 +1,96 @@ +use boring::ssl::ClientHello; +use futures::future; +use tokio::task::yield_now; +use tokio_boring::{ + AsyncSelectCertError, BoxSelectCertFinish, BoxSelectCertFuture, SslContextBuilderExt, +}; + +mod common; + +use self::common::{connect, create_server, with_trivial_client_server_exchange}; + +#[tokio::test] +async fn test_async_select_certificate_callback_trivial() { + with_trivial_client_server_exchange(|builder| { + builder.set_async_select_certificate_callback(|_| { + Ok(Box::pin(async { + Ok(Box::new(|_: ClientHello<'_>| Ok(())) as BoxSelectCertFinish) + })) + }); + }) + .await; +} + +#[tokio::test] +async fn test_async_select_certificate_callback_yield() { + with_trivial_client_server_exchange(|builder| { + builder.set_async_select_certificate_callback(|_| { + Ok(Box::pin(async { + yield_now().await; + + Ok(Box::new(|_: ClientHello<'_>| Ok(())) as BoxSelectCertFinish) + })) + }); + }) + .await; +} + +#[tokio::test] +async fn test_async_select_certificate_callback_return_error() { + with_async_select_certificate_callback_error(|_| Err(AsyncSelectCertError)).await; +} + +#[tokio::test] +async fn test_async_select_certificate_callback_future_error() { + with_async_select_certificate_callback_error(|_| { + Ok(Box::pin(async move { Err(AsyncSelectCertError) })) + }) + .await; +} + +#[tokio::test] +async fn test_async_select_certificate_callback_future_yield_error() { + with_async_select_certificate_callback_error(|_| { + Ok(Box::pin(async move { + yield_now().await; + + Err(AsyncSelectCertError) + })) + }) + .await; +} + +#[tokio::test] +async fn test_async_select_certificate_callback_finish_error() { + with_async_select_certificate_callback_error(|_| { + Ok(Box::pin(async move { + yield_now().await; + + Ok(Box::new(|_: ClientHello<'_>| Err(AsyncSelectCertError)) as BoxSelectCertFinish) + })) + }) + .await; +} + +async fn with_async_select_certificate_callback_error( + callback: impl Fn(&mut ClientHello<'_>) -> Result + + Send + + Sync + + 'static, +) { + let (stream, addr) = create_server(|builder| { + builder.set_async_select_certificate_callback(callback); + }); + + let server = async { + let _err = stream.await.unwrap_err(); + }; + + let client = async { + let _err = connect(addr, |builder| builder.set_ca_file("tests/cert.pem")) + .await + .unwrap_err(); + }; + + future::join(server, client).await; +} From 7d09f88914d92837f80c54d37e3a315c6452f1bb Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Fri, 6 Oct 2023 13:05:12 +0200 Subject: [PATCH 10/12] Introduce Ssl::set_certificate --- boring/src/ssl/mod.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index d4f763ff7..efa2c8ca2 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -3258,6 +3258,19 @@ impl SslRef { pub fn set_mtu(&mut self, mtu: u32) -> Result<(), ErrorStack> { unsafe { cvt(ffi::SSL_set_mtu(self.as_ptr(), mtu as c_uint) as c_int).map(|_| ()) } } + + /// Sets the certificate. + /// + /// This corresponds to [`SSL_use_certificate`]. + /// + /// [`SSL_use_certificate`]: https://www.openssl.org/docs/man1.1.1/man3/SSL_use_certificate.html + pub fn set_certificate(&mut self, cert: &X509Ref) -> Result<(), ErrorStack> { + unsafe { + cvt(ffi::SSL_use_certificate(self.as_ptr(), cert.as_ptr()))?; + } + + Ok(()) + } } /// An SSL stream midway through the handshake process. From 94d8cd7ab89f1dbc3cf67ce5e17ad94ee651de95 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Fri, 6 Oct 2023 13:25:31 +0200 Subject: [PATCH 11/12] Introduce SslSignatureAlgorithm::RSA_PKCS1_MD5_SHA1 --- boring/src/ssl/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index efa2c8ca2..724d78232 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -597,6 +597,9 @@ impl SslSignatureAlgorithm { pub const RSA_PKCS1_SHA512: SslSignatureAlgorithm = SslSignatureAlgorithm(ffi::SSL_SIGN_RSA_PKCS1_SHA512 as _); + pub const RSA_PKCS1_MD5_SHA1: SslSignatureAlgorithm = + SslSignatureAlgorithm(ffi::SSL_SIGN_RSA_PKCS1_MD5_SHA1 as _); + pub const ECDSA_SHA1: SslSignatureAlgorithm = SslSignatureAlgorithm(ffi::SSL_SIGN_ECDSA_SHA1 as _); From ddfeb97f774261e2368a992706fa3dff1a297855 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Mon, 9 Oct 2023 11:55:35 +0200 Subject: [PATCH 12/12] Remove futures from ex data slots once they resolve --- tokio-boring/src/async_callbacks.rs | 32 +++++++++++++---------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/tokio-boring/src/async_callbacks.rs b/tokio-boring/src/async_callbacks.rs index ee658ea72..b12ad2b27 100644 --- a/tokio-boring/src/async_callbacks.rs +++ b/tokio-boring/src/async_callbacks.rs @@ -26,10 +26,10 @@ pub type ExDataFuture = Pin + Send + Sync>>; pub(crate) static TASK_WAKER_INDEX: Lazy>> = Lazy::new(|| Ssl::new_ex_index().unwrap()); -pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy> = +pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy>> = Lazy::new(|| Ssl::new_ex_index().unwrap()); pub(crate) static SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX: Lazy< - Index, + Index>, > = Lazy::new(|| Ssl::new_ex_index().unwrap()); /// Extensions to [`SslContextBuilder`]. @@ -219,7 +219,7 @@ fn with_private_key_method( /// created by `create_fut` returns `Poll::Ready(_)` on the first poll call. fn with_ex_data_future( ssl_handle: &mut H, - index: Index>>, + index: Index>>>, get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef, create_fut: impl FnOnce(&mut H) -> Result>, E>, ) -> Poll> { @@ -232,25 +232,21 @@ fn with_ex_data_future( let mut ctx = Context::from_waker(&waker); - match ssl.ex_data_mut(index) { - Some(fut) => { - let fut_result = ready!(fut.as_mut().poll(&mut ctx)); + if let Some(data @ Some(_)) = ssl.ex_data_mut(index) { + let fut_result = ready!(data.as_mut().unwrap().as_mut().poll(&mut ctx)); - // NOTE(nox): For memory usage concerns, maybe we should implement - // a way to remove the stored future from the `Ssl` value here? + *data = None; - Poll::Ready(fut_result) - } - None => { - let mut fut = create_fut(ssl_handle)?; + Poll::Ready(fut_result) + } else { + let mut fut = create_fut(ssl_handle)?; - match fut.as_mut().poll(&mut ctx) { - Poll::Ready(fut_result) => Poll::Ready(fut_result), - Poll::Pending => { - get_ssl_mut(ssl_handle).set_ex_data(index, fut); + match fut.as_mut().poll(&mut ctx) { + Poll::Ready(fut_result) => Poll::Ready(fut_result), + Poll::Pending => { + get_ssl_mut(ssl_handle).set_ex_data(index, Some(fut)); - Poll::Pending - } + Poll::Pending } } }