Skip to content

Commit 7ffc133

Browse files
committed
Rebase fixes
1 parent 8677264 commit 7ffc133

File tree

2 files changed

+56
-26
lines changed

2 files changed

+56
-26
lines changed

sqlx-core/src/net/tls/tls_native_tls.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ pub async fn handshake<S: Socket>(
5757
if let (Some(cert_path), Some(key_path)) = (config.client_cert_path, config.client_key_path) {
5858
let cert_path = cert_path.data().await?;
5959
let key_path = key_path.data().await?;
60-
let identity =
61-
Identity::from_pkcs8(&cert_path, &key_path).map_err(|e| Error::Tls(e.into()))?;
60+
let identity = Identity::from_pkcs8(&cert_path, &key_path).map_err(Error::tls)?;
6261
builder.identity(identity);
6362
}
6463

sqlx-core/src/net/tls/tls_rustls.rs

+55-24
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use futures_util::future;
2-
use std::io;
3-
use std::io::{Cursor, Read, Write};
2+
use std::io::{self, BufReader, Cursor, Read, Write};
43
use std::sync::Arc;
54
use std::task::{Context, Poll};
65
use std::time::SystemTime;
@@ -48,7 +47,7 @@ impl<S: Socket> Socket for RustlsSocket<S> {
4847
match self.state.writer().write(buf) {
4948
// Returns a zero-length write when the buffer is full.
5049
Ok(0) => Err(io::ErrorKind::WouldBlock.into()),
51-
other => return other,
50+
other => other,
5251
}
5352
}
5453

@@ -81,10 +80,32 @@ where
8180
{
8281
let config = ClientConfig::builder().with_safe_defaults();
8382

83+
// authentication using user's key and its associated certificate
84+
let user_auth = match (tls_config.client_cert_path, tls_config.client_key_path) {
85+
(Some(cert_path), Some(key_path)) => {
86+
let cert_chain = certs_from_pem(cert_path.data().await?)?;
87+
let key_der = private_key_from_pem(key_path.data().await?)?;
88+
Some((cert_chain, key_der))
89+
}
90+
(None, None) => None,
91+
(_, _) => {
92+
return Err(Error::Configuration(
93+
"user auth key and certs must be given together".into(),
94+
))
95+
}
96+
};
97+
8498
let config = if tls_config.accept_invalid_certs {
85-
config
86-
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
87-
.with_no_client_auth()
99+
if let Some(user_auth) = user_auth {
100+
config
101+
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
102+
.with_single_cert(user_auth.0, user_auth.1)
103+
.map_err(Error::tls)?
104+
} else {
105+
config
106+
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
107+
.with_no_client_auth()
108+
}
88109
} else {
89110
let mut cert_store = RootCertStore::empty();
90111
cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
@@ -100,37 +121,22 @@ where
100121
let mut cursor = Cursor::new(data);
101122

102123
for cert in rustls_pemfile::certs(&mut cursor)
103-
.map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))?
124+
.map_err(|_| Error::Tls(format!("Invalid certificate {ca}").into()))?
104125
{
105126
cert_store
106127
.add(&rustls::Certificate(cert))
107128
.map_err(|err| Error::Tls(err.into()))?;
108129
}
109130
}
110131

111-
// authentication using user's key and its associated certificate
112-
let user_auth = match (tls_config.client_cert_path, tls_config.client_key_path) {
113-
(Some(cert_path), Some(key_path)) => {
114-
let cert_chain = certs_from_pem(cert_path.data().await?)?;
115-
let key_der = private_key_from_pem(key_path.data().await?)?;
116-
Some((cert_chain, key_der))
117-
}
118-
(None, None) => None,
119-
(_, _) => {
120-
return Err(Error::Configuration(
121-
"user auth key and certs must be given together".into(),
122-
))
123-
}
124-
};
125-
126132
if tls_config.accept_invalid_hostnames {
127133
let verifier = WebPkiVerifier::new(cert_store, None);
128134

129135
if let Some(user_auth) = user_auth {
130136
config
131137
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
132138
.with_single_cert(user_auth.0, user_auth.1)
133-
.map_err(|err| Error::Tls(err.into()))?
139+
.map_err(Error::tls)?
134140
} else {
135141
config
136142
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
@@ -140,7 +146,7 @@ where
140146
config
141147
.with_root_certificates(cert_store)
142148
.with_single_cert(user_auth.0, user_auth.1)
143-
.map_err(|err| Error::Tls(err.into()))?
149+
.map_err(Error::tls)?
144150
} else {
145151
config
146152
.with_root_certificates(cert_store)
@@ -162,6 +168,31 @@ where
162168
Ok(socket)
163169
}
164170

171+
fn certs_from_pem(pem: Vec<u8>) -> Result<Vec<rustls::Certificate>, Error> {
172+
let cur = Cursor::new(pem);
173+
let mut reader = BufReader::new(cur);
174+
rustls_pemfile::certs(&mut reader)?
175+
.into_iter()
176+
.map(|v| Ok(rustls::Certificate(v)))
177+
.collect()
178+
}
179+
180+
fn private_key_from_pem(pem: Vec<u8>) -> Result<rustls::PrivateKey, Error> {
181+
let cur = Cursor::new(pem);
182+
let mut reader = BufReader::new(cur);
183+
184+
loop {
185+
match rustls_pemfile::read_one(&mut reader)? {
186+
Some(rustls_pemfile::Item::RSAKey(key)) => return Ok(rustls::PrivateKey(key)),
187+
Some(rustls_pemfile::Item::PKCS8Key(key)) => return Ok(rustls::PrivateKey(key)),
188+
None => break,
189+
_ => {}
190+
}
191+
}
192+
193+
Err(Error::Configuration("no keys found pem file".into()))
194+
}
195+
165196
struct DummyTlsVerifier;
166197

167198
impl ServerCertVerifier for DummyTlsVerifier {

0 commit comments

Comments
 (0)