Skip to content

Commit f029a0e

Browse files
committed
simpler impl of dynamically swapping tls creds
closes rwf2#2363
1 parent 97992b6 commit f029a0e

File tree

6 files changed

+117
-15
lines changed

6 files changed

+117
-15
lines changed
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
use std::{io, sync::Arc};
2+
3+
use rustls::{server::ClientHello, sign::{any_supported_type, CertifiedKey}};
4+
5+
use crate::tls::Config;
6+
use crate::tls::util::{load_certs, load_private_key};
7+
8+
pub(crate) struct CertResolver(Arc<CertifiedKey>);
9+
impl CertResolver {
10+
pub fn new<R>(config: &mut Config<R>) -> Result<Self, std::io::Error>
11+
where R: io::BufRead,
12+
{
13+
let certs = load_certs(&mut config.cert_chain)?;
14+
let private_key = load_private_key(&mut config.private_key)?;
15+
let key = any_supported_type(&private_key)
16+
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?;
17+
18+
Ok(
19+
Self(Arc::new(CertifiedKey::new(certs, key))))
20+
}
21+
}
22+
23+
impl rustls::server::ResolvesServerCert for CertResolver {
24+
fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
25+
Some(self.0.clone())
26+
}
27+
}

core/http/src/tls/listener.rs

+11-11
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ use std::task::{Context, Poll};
55
use std::future::Future;
66
use std::net::SocketAddr;
77

8+
use rustls::server::ResolvesServerCert;
89
use tokio::net::{TcpListener, TcpStream};
910
use tokio::io::{AsyncRead, AsyncWrite};
1011
use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream};
1112

12-
use crate::tls::util::{load_certs, load_private_key, load_ca_certs};
13+
use crate::tls::util::load_ca_certs;
1314
use crate::listener::{Connection, Listener, Certificates};
15+
use crate::tls::CertResolver;
1416

1517
/// A TLS listener over TCP.
1618
pub struct TlsListener {
@@ -72,18 +74,12 @@ pub struct Config<R> {
7274
}
7375

7476
impl TlsListener {
75-
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> io::Result<TlsListener>
76-
where R: io::BufRead
77+
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>, cert_resolver: Option<&Arc<dyn ResolvesServerCert>>) -> io::Result<TlsListener>
78+
where R: io::BufRead,
7779
{
7880
use rustls::server::{AllowAnyAuthenticatedClient, AllowAnyAnonymousOrAuthenticatedClient};
7981
use rustls::server::{NoClientAuth, ServerSessionMemoryCache, ServerConfig};
8082

81-
let cert_chain = load_certs(&mut c.cert_chain)
82-
.map_err(|e| io::Error::new(e.kind(), format!("bad TLS cert chain: {}", e)))?;
83-
84-
let key = load_private_key(&mut c.private_key)
85-
.map_err(|e| io::Error::new(e.kind(), format!("bad TLS private key: {}", e)))?;
86-
8783
let client_auth = match c.ca_certs {
8884
Some(ref mut ca_certs) => match load_ca_certs(ca_certs) {
8985
Ok(ca) if c.mandatory_mtls => AllowAnyAuthenticatedClient::new(ca).boxed(),
@@ -93,14 +89,18 @@ impl TlsListener {
9389
None => NoClientAuth::boxed(),
9490
};
9591

92+
let cert_resolver = match cert_resolver {
93+
Some(c) => c.clone(),
94+
None => Arc::new(CertResolver::new(&mut c)?),
95+
};
96+
9697
let mut tls_config = ServerConfig::builder()
9798
.with_cipher_suites(&c.ciphersuites)
9899
.with_safe_default_kx_groups()
99100
.with_safe_default_protocol_versions()
100101
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?
101102
.with_client_cert_verifier(client_auth)
102-
.with_single_cert(cert_chain, key)
103-
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?;
103+
.with_cert_resolver(cert_resolver);
104104

105105
tls_config.ignore_client_order = c.prefer_server_order;
106106

core/http/src/tls/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ mod listener;
33
#[cfg(feature = "mtls")]
44
pub mod mtls;
55

6+
pub(crate) mod certificate_resolver;
7+
68
pub use rustls;
79
pub use listener::{TlsListener, Config};
10+
pub(crate) use certificate_resolver::*;
811
pub mod util;

core/lib/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,5 @@ version_check = "0.9.1"
8686

8787
[dev-dependencies]
8888
figment = { version = "0.10", features = ["test"] }
89+
reqwest = { version = "0.11", features = ["blocking"] }
8990
pretty_assertions = "1"

core/lib/src/server.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::sync::Arc;
33
use std::time::Duration;
44
use std::pin::Pin;
55

6-
use yansi::Paint;
76
use tokio::sync::oneshot;
7+
use yansi::Paint;
88
use tokio::time::sleep;
99
use futures::stream::StreamExt;
1010
use futures::future::{FutureExt, Future, BoxFuture};
@@ -421,9 +421,12 @@ impl Rocket<Orbit> {
421421
if self.config.tls_enabled() {
422422
if let Some(ref config) = self.config.tls {
423423
use crate::http::tls::TlsListener;
424+
use crate::http::tls::rustls::server::ResolvesServerCert;
424425

425426
let conf = config.to_native_config().map_err(ErrorKind::Io)?;
426-
let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::Bind)?;
427+
let resolver = self.state::<Arc<dyn ResolvesServerCert>>();
428+
429+
let l = TlsListener::bind(addr, conf, resolver).await.map_err(ErrorKind::Bind)?;
427430
addr = l.local_addr().unwrap_or(addr);
428431
self.config.address = addr.ip();
429432
self.config.port = addr.port();

core/lib/tests/tls-config-from-source-1503.rs

+70-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ fn tls_config_from_source() {
1111
use rocket::config::{Config, TlsConfig};
1212
use rocket::figment::Figment;
1313

14-
let cert_path = relative!("examples/tls/private/cert.pem");
15-
let key_path = relative!("examples/tls/private/key.pem");
14+
let cert_path = relative!("../../examples/tls/private/cert.pem");
15+
let key_path = relative!("../../examples/tls/private/key.pem");
1616

1717
let rocket_config = Config {
1818
tls: Some(TlsConfig::from_paths(cert_path, key_path)),
@@ -24,3 +24,71 @@ fn tls_config_from_source() {
2424
assert_eq!(tls.certs().unwrap_left(), cert_path);
2525
assert_eq!(tls.key().unwrap_left(), key_path);
2626
}
27+
28+
#[test]
29+
fn tls_server_operation() {
30+
use std::io::Read;
31+
32+
use rocket::{get, routes};
33+
use rocket::config::{Config, TlsConfig};
34+
use rocket::figment::Figment;
35+
36+
let cert_path = relative!("../../examples/tls/private/rsa_sha256_cert.pem");
37+
let key_path = relative!("../../examples/tls/private/rsa_sha256_key.pem");
38+
let ca_cert_path = relative!("../../examples/tls/private/ca_cert.pem");
39+
40+
println!("{cert_path:?}");
41+
42+
let port = {
43+
let listener = std::net::TcpListener::bind(("127.0.0.1", 0)).expect("creating listener");
44+
listener.local_addr().expect("getting listener's port").port()
45+
};
46+
47+
let rocket_config = Config {
48+
port,
49+
tls: Some(TlsConfig::from_paths(cert_path, key_path)),
50+
..Default::default()
51+
};
52+
let config: Config = Figment::from(rocket_config).extract().expect("creating config");
53+
let (shutdown_signal_sender, mut shutdown_signal_receiver) = tokio::sync::mpsc::channel::<()>(1);
54+
55+
// Create a runtime in a separate thread for the server being tested
56+
let join_handle = std::thread::spawn(move || {
57+
let rt = tokio::runtime::Runtime::new().unwrap();
58+
59+
#[get("/hello")]
60+
fn tls_test_get() -> &'static str {
61+
"world"
62+
}
63+
64+
rt.block_on(async {
65+
let task_handle = tokio::spawn( async {
66+
rocket::custom(config)
67+
.mount("/", routes![tls_test_get])
68+
.launch().await.unwrap();
69+
});
70+
shutdown_signal_receiver.recv().await;
71+
task_handle.abort();
72+
});
73+
});
74+
75+
let request_url = format!("https://localhost:{}/hello", port);
76+
77+
// CA certificate is not loaded, so request should fail
78+
assert!(reqwest::blocking::get(&request_url).is_err());
79+
80+
// Load the CA certicate for use with test client
81+
let cert = {
82+
let mut buf = Vec::new();
83+
std::fs::File::open(ca_cert_path).expect("open ca_certs")
84+
.read_to_end(&mut buf).expect("read ca_certs");
85+
reqwest::Certificate::from_pem(&buf).expect("create certificate")
86+
};
87+
let client = reqwest::blocking::Client::builder().add_root_certificate(cert).build().expect("build client");
88+
89+
let response = client.get(&request_url).send().expect("https request");
90+
assert_eq!(&response.text().unwrap(), "world");
91+
92+
shutdown_signal_sender.blocking_send(()).expect("signal shutdown");
93+
join_handle.join().expect("join thread");
94+
}

0 commit comments

Comments
 (0)