Skip to content

Commit dad8312

Browse files
committed
Introduce dynamic TLS resolvers.
This commit introduces the ability to dynamically select a TLS configuration based on the client's TLS hello. Added `Authority::set_port()`. Various `Config` structures for listeners removed. `UdsListener` is now `UnixListener`. `Bindable` removed in favor of new `Bind`. `Connection` requires `AsyncRead + AsyncWrite` again The `Debug` impl for `Endpoint` displays the underlying address in plaintext. `Listener` must be `Sized`. `tls` listener moved to `tls::TlsListener` The preview `quic` listener no longer implements `Listener`. All built-in listeners now implement `Bind<&Rocket>`. Clarified docs for `mtls::Certificate` guard. No reexporitng rustls from `tls`. Added `TlsConfig::server_config()`. Added some future helpers: `race()` and `race_io()`. Fix an issue where the logger wouldn't respect a configuration during error printing. Added Rocket::launch_with(), launch_on(), bind_launch(). Added a default client.pem to the TLS example. Revamped the testbench. Added tests for TLS resolvers, MTLS, listener failure output. TODO: clippy. TODO: UDS testing. Resolves #2730. Resolves #2363. Closes #2748. Closes #2683. Closes #2577.
1 parent 280fda4 commit dad8312

File tree

5 files changed

+175
-86
lines changed

5 files changed

+175
-86
lines changed

core/lib/src/error.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -179,16 +179,16 @@ impl Error {
179179
match self.kind() {
180180
ErrorKind::Bind(ref a, ref e) => {
181181
if let Some(e) = e.downcast_ref::<Self>() {
182-
e.pretty_print();
182+
e.pretty_print()
183183
} else {
184184
match a {
185185
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
186186
None => error!("Binding to network interface failed."),
187187
}
188-
}
189188

190-
info_!("{}", e);
191-
"aborting due to bind error"
189+
info_!("{}", e);
190+
"aborting due to bind error"
191+
}
192192
}
193193
ErrorKind::Io(ref e) => {
194194
error!("Rocket failed to launch due to an I/O error.");

core/lib/src/listener/default.rs

+75-32
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
1+
use core::fmt;
2+
13
use serde::Deserialize;
2-
use tokio_util::either::{Either, Either::{Left, Right}};
3-
use futures::TryFutureExt;
4+
use tokio_util::either::Either::{Left, Right};
5+
use either::Either;
46

5-
use crate::error::ErrorKind;
67
use crate::{Ignite, Rocket};
78
use crate::listener::{Bind, Endpoint, tcp::TcpListener};
89

910
#[cfg(unix)] use crate::listener::unix::UnixListener;
1011
#[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig};
1112

1213
mod private {
13-
use super::{Either, TcpListener};
14+
use tokio_util::either::Either;
1415

1516
#[cfg(feature = "tls")] pub type TlsListener<T> = super::TlsListener<T>;
1617
#[cfg(not(feature = "tls"))] pub type TlsListener<T> = T;
1718
#[cfg(unix)] pub type UnixListener = super::UnixListener;
1819
#[cfg(not(unix))] pub type UnixListener = super::TcpListener;
1920

2021
pub type Listener = Either<
21-
Either<TlsListener<TcpListener>, TlsListener<UnixListener>>,
22-
Either<TcpListener, UnixListener>,
22+
Either<TlsListener<super::TcpListener>, TlsListener<UnixListener>>,
23+
Either<super::TcpListener, UnixListener>,
2324
>;
2425
}
2526

@@ -33,48 +34,90 @@ struct Config {
3334

3435
pub type DefaultListener = private::Listener;
3536

37+
#[derive(Debug)]
38+
pub enum Error {
39+
Config(figment::Error),
40+
Io(std::io::Error),
41+
Unsupported(Endpoint),
42+
#[cfg(feature = "tls")]
43+
Tls(crate::tls::Error),
44+
}
45+
46+
impl From<figment::Error> for Error {
47+
fn from(value: figment::Error) -> Self {
48+
Error::Config(value)
49+
}
50+
}
51+
52+
impl From<std::io::Error> for Error {
53+
fn from(value: std::io::Error) -> Self {
54+
Error::Io(value)
55+
}
56+
}
57+
58+
#[cfg(feature = "tls")]
59+
impl From<crate::tls::Error> for Error {
60+
fn from(value: crate::tls::Error) -> Self {
61+
Error::Tls(value)
62+
}
63+
}
64+
65+
impl From<Either<figment::Error, std::io::Error>> for Error {
66+
fn from(value: Either<figment::Error, std::io::Error>) -> Self {
67+
value.either(Error::Config, Error::Io)
68+
}
69+
}
70+
71+
impl fmt::Display for Error {
72+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73+
match self {
74+
Error::Config(e) => e.fmt(f),
75+
Error::Io(e) => e.fmt(f),
76+
Error::Unsupported(e) => write!(f, "unsupported endpoint: {e:?}"),
77+
#[cfg(feature = "tls")]
78+
Error::Tls(error) => error.fmt(f),
79+
}
80+
}
81+
}
82+
83+
impl std::error::Error for Error {
84+
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
85+
match self {
86+
Error::Config(e) => Some(e),
87+
Error::Io(e) => Some(e),
88+
Error::Unsupported(_) => None,
89+
#[cfg(feature = "tls")]
90+
Error::Tls(e) => Some(e),
91+
}
92+
}
93+
}
94+
3695
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
37-
type Error = crate::Error;
96+
type Error = Error;
3897

3998
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
4099
let config: Config = rocket.figment().extract()?;
41100
match config.address {
42101
#[cfg(feature = "tls")]
43-
endpoint@Endpoint::Tcp(_) if config.tls.is_some() => {
44-
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket)
45-
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
46-
.await?;
47-
102+
Endpoint::Tcp(_) if config.tls.is_some() => {
103+
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket).await?;
48104
Ok(Left(Left(listener)))
49105
}
50-
endpoint@Endpoint::Tcp(_) => {
51-
let listener = <TcpListener as Bind<_>>::bind(rocket)
52-
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
53-
.await?;
54-
106+
Endpoint::Tcp(_) => {
107+
let listener = <TcpListener as Bind<_>>::bind(rocket).await?;
55108
Ok(Right(Left(listener)))
56109
}
57110
#[cfg(all(unix, feature = "tls"))]
58-
endpoint@Endpoint::Unix(_) if config.tls.is_some() => {
59-
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket)
60-
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
61-
.await?;
62-
111+
Endpoint::Unix(_) if config.tls.is_some() => {
112+
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket).await?;
63113
Ok(Left(Right(listener)))
64114
}
65115
#[cfg(unix)]
66-
endpoint@Endpoint::Unix(_) => {
67-
let listener = <UnixListener as Bind<_>>::bind(rocket)
68-
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
69-
.await?;
70-
116+
Endpoint::Unix(_) => {
117+
let listener = <UnixListener as Bind<_>>::bind(rocket).await?;
71118
Ok(Right(Right(listener)))
72119
}
73-
endpoint => {
74-
let msg = format!("unsupported bind endpoint: {endpoint}");
75-
let error = Box::<dyn std::error::Error + Send + Sync>::from(msg);
76-
Err(ErrorKind::Bind(Some(endpoint), error).into())
77-
}
120+
endpoint => Err(Error::Unsupported(endpoint)),
78121
}
79122
}
80123

core/lib/src/tls/resolver.rs

+33
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,39 @@ pub(crate) struct DynResolver(Arc<dyn Resolver>);
1515
pub struct Fairing<T: ?Sized>(PhantomData<T>);
1616

1717
/// A dynamic TLS configuration resolver.
18+
///
19+
/// # Example
20+
///
21+
/// This is an async trait. Implement it as follows:
22+
///
23+
/// ```rust
24+
/// # #[macro_use] extern crate rocket;
25+
/// use std::sync::Arc;
26+
/// use rocket::tls::{self, Resolver, TlsConfig, ClientHello, ServerConfig};
27+
/// use rocket::{Rocket, Build};
28+
///
29+
/// struct MyResolver(Arc<ServerConfig>);
30+
///
31+
/// #[rocket::async_trait]
32+
/// impl Resolver for MyResolver {
33+
/// async fn init(rocket: &Rocket<Build>) -> tls::Result<Self> {
34+
/// // This is equivalent to what the default resolver would do.
35+
/// let config: TlsConfig = rocket.figment().extract_inner("tls")?;
36+
/// let server_config = config.server_config().await?;
37+
/// Ok(MyResolver(Arc::new(server_config)))
38+
/// }
39+
///
40+
/// async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
41+
/// // return a `ServerConfig` based on `hello`; here we ignore it
42+
/// Some(self.0.clone())
43+
/// }
44+
/// }
45+
///
46+
/// #[launch]
47+
/// fn rocket() -> _ {
48+
/// rocket::build().attach(MyResolver::fairing())
49+
/// }
50+
/// ```
1851
#[crate::async_trait]
1952
pub trait Resolver: Send + Sync + 'static {
2053
async fn init(rocket: &Rocket<Build>) -> crate::tls::Result<Self> where Self: Sized {

testbench/src/main.rs

+49-26
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::process::ExitCode;
2+
use std::time::Duration;
23

34
use rocket::listener::unix::UnixListener;
45
use rocket::tokio::net::TcpListener;
@@ -163,9 +164,7 @@ fn tls_resolver() -> Result<()> {
163164
let server = spawn! {
164165
#[get("/count")]
165166
fn count(counter: &State<Arc<AtomicUsize>>) -> String {
166-
let count = counter.load(Ordering::Acquire);
167-
println!("{count}");
168-
count.to_string()
167+
counter.load(Ordering::Acquire).to_string()
169168
}
170169

171170
let counter = Arc::new(AtomicUsize::new(0));
@@ -329,8 +328,8 @@ fn tcp_unix_listener_fail() -> Result<()> {
329328
};
330329

331330
if let Err(Error::Liftoff(stdout, _)) = server {
332-
assert!(stdout.contains("expected valid TCP"));
333-
assert!(stdout.contains("for key default.address"));
331+
assert!(stdout.contains("expected valid TCP (ip) or unix (path)"));
332+
assert!(stdout.contains("default.address"));
334333
} else {
335334
panic!("unexpected result: {server:#?}");
336335
}
@@ -361,14 +360,17 @@ fn tcp_unix_listener_fail() -> Result<()> {
361360

362361
macro_rules! tests {
363362
($($f:ident),* $(,)?) => {[
364-
$(Test { name: stringify!($f), func: $f, }),*
363+
$(Test {
364+
name: stringify!($f),
365+
run: |_: ()| $f().map_err(|e| e.to_string()),
366+
}),*
365367
]};
366368
}
367369

368370
#[derive(Copy, Clone)]
369371
struct Test {
370372
name: &'static str,
371-
func: fn() -> Result<()>,
373+
run: fn(()) -> Result<(), String>,
372374
}
373375

374376
static TESTS: &[Test] = &tests![
@@ -377,37 +379,58 @@ static TESTS: &[Test] = &tests![
377379
];
378380

379381
fn main() -> ExitCode {
382+
procspawn::init();
383+
380384
let filter = std::env::args().nth(1).unwrap_or_default();
381385
let filtered = TESTS.into_iter().filter(|test| test.name.contains(&filter));
382386

383387
println!("running {}/{} tests", filtered.clone().count(), TESTS.len());
384-
let handles: Vec<_> = filtered
385-
.map(|test| (test, std::thread::spawn(move || {
386-
if let Err(e) = (test.func)() {
387-
println!("test {} ... {}\n {e}", test.name.bold(), "fail".red());
388-
return Err(e);
388+
let handles = filtered.map(|test| (test, std::thread::spawn(|| {
389+
let name = test.name;
390+
let start = std::time::SystemTime::now();
391+
let mut proc = procspawn::spawn((), test.run);
392+
let result = loop {
393+
match proc.join_timeout(Duration::from_secs(10)) {
394+
Err(e) if e.is_timeout() => {
395+
let elapsed = start.elapsed().unwrap().as_secs();
396+
println!("{name} has been running for {elapsed} seconds...");
397+
398+
if elapsed >= 30 {
399+
println!("{name} timeout");
400+
break Err(e);
401+
}
402+
},
403+
result => break result,
389404
}
405+
};
390406

391-
println!("test {} ... {}", test.name.bold(), "ok".green());
392-
Ok(())
393-
})))
394-
.collect();
395-
396-
let mut failure = false;
397-
for (test, handle) in handles {
398-
let result = handle.join();
399-
failure |= matches!(result, Err(_) | Ok(Err(_)));
400-
if result.is_err() {
401-
println!("test {} ... {}", test.name.bold(), "panic".red().underline());
407+
match result.as_ref().map_err(|e| e.panic_info()) {
408+
Ok(Ok(_)) => println!("test {name} ... {}", "ok".green()),
409+
Ok(Err(e)) => println!("test {name} ... {}\n {e}", "fail".red()),
410+
Err(Some(_)) => println!("test {name} ... {}", "panic".red().underline()),
411+
Err(None) => println!("test {name} ... {}", "error".magenta()),
402412
}
413+
414+
matches!(result, Ok(Ok(())))
415+
})));
416+
417+
let mut success = true;
418+
for (_, handle) in handles {
419+
success &= handle.join().unwrap_or(false);
403420
}
404421

405-
match failure {
406-
true => ExitCode::FAILURE,
407-
false => ExitCode::SUCCESS
422+
match success {
423+
true => ExitCode::SUCCESS,
424+
false => {
425+
println!("note: use `NOCAPTURE=1` to see test output");
426+
ExitCode::FAILURE
427+
}
408428
}
409429
}
410430

431+
// TODO: Implement an `UpdatingResolver`. Expose `SniResolver` and
432+
// `UpdatingResolver` in a `contrib` library or as part of `rocket`.
433+
//
411434
// struct UpdatingResolver {
412435
// timestamp: AtomicU64,
413436
// config: ArcSwap<ServerConfig>

0 commit comments

Comments
 (0)