Skip to content

Commit

Permalink
fix(bindings): remove mutation behind Arc
Browse files Browse the repository at this point in the history
  • Loading branch information
jmayclin committed Feb 18, 2025
1 parent e4a5a74 commit 88ed139
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 44 deletions.
65 changes: 27 additions & 38 deletions bindings/rust/extended/s2n-tls/src/cert_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,38 @@ use std::{
///
/// [CertificateChain] is internally reference counted. The reference counted `T`
/// must have a drop implementation.
struct CertificateChainHandle {
cert: NonNull<s2n_cert_chain_and_key>,
pub(crate) struct CertificateChainHandle<'a> {
pub(crate) cert: NonNull<s2n_cert_chain_and_key>,
is_owned: bool,
_lifetime: PhantomData<&'a s2n_cert_chain_and_key>,
}

// # Safety
//
// s2n_cert_chain_and_key objects can be sent across threads.
unsafe impl Send for CertificateChainHandle {}
unsafe impl Sync for CertificateChainHandle {}
unsafe impl Send for CertificateChainHandle<'_> {}
unsafe impl Sync for CertificateChainHandle<'_> {}

impl CertificateChainHandle {
fn from_owned(cert: NonNull<s2n_cert_chain_and_key>) -> Self {
Self {
cert,
impl CertificateChainHandle<'_> {
pub(crate) fn allocate() -> Result<CertificateChainHandle<'static>, crate::error::Error> {
crate::init::init();
Ok(CertificateChainHandle {
cert: unsafe { s2n_cert_chain_and_key_new().into_result() }?,
is_owned: true,
}
_lifetime: PhantomData,
})
}

fn from_reference(cert: NonNull<s2n_cert_chain_and_key>) -> Self {
Self {
cert,
is_owned: false,
_lifetime: PhantomData,
}
}
}

impl Drop for CertificateChainHandle {
impl Drop for CertificateChainHandle<'_> {
/// Corresponds to [s2n_cert_chain_and_key_free].
fn drop(&mut self) {
// ignore failures since there's not much we can do about it
Expand All @@ -53,13 +57,13 @@ impl Drop for CertificateChainHandle {
}

pub struct Builder {
cert: CertificateChain<'static>,
cert: CertificateChainHandle<'static>,
}

impl Builder {
pub fn new() -> Result<Self, Error> {
Ok(Self {
cert: CertificateChain::allocate_owned()?,
cert: CertificateChainHandle::allocate()?,
})
}

Expand All @@ -73,7 +77,7 @@ impl Builder {
// `private_key_pem` are not modified.
// https://github.com/aws/s2n-tls/issues/4140
s2n_cert_chain_and_key_load_pem_bytes(
self.cert.as_mut_ptr(),
self.cert.cert.as_ptr(),
chain.as_ptr() as *mut _,
chain.len() as u32,
key.as_ptr() as *mut _,
Expand All @@ -95,7 +99,7 @@ impl Builder {
// is not modified
// https://github.com/aws/s2n-tls/issues/4140
s2n_cert_chain_and_key_load_public_pem_bytes(
self.cert.as_mut_ptr(),
self.cert.cert.as_ptr(),
chain.as_ptr() as *mut _,
chain.len() as u32,
)
Expand All @@ -109,7 +113,7 @@ impl Builder {
pub fn set_ocsp_data(&mut self, data: &[u8]) -> Result<&mut Self, Error> {
unsafe {
s2n_cert_chain_and_key_set_ocsp_data(
self.cert.as_mut_ptr(),
self.cert.cert.as_ptr(),
data.as_ptr(),
data.len() as u32,
)
Expand All @@ -122,7 +126,7 @@ impl Builder {
pub fn build(self) -> Result<CertificateChain<'static>, Error> {
// This method is currently infallible, but returning a result allows
// us to add validation in the future.
Ok(self.cert)
Ok(CertificateChain::from_allocated(self.cert))
}
}

Expand All @@ -135,22 +139,18 @@ impl Builder {
// safe to mutate CertificateChains.
#[derive(Clone)]
pub struct CertificateChain<'a> {
ptr: Arc<CertificateChainHandle>,
_lifetime: PhantomData<&'a s2n_cert_chain_and_key>,
ptr: Arc<CertificateChainHandle<'a>>,
}

impl CertificateChain<'_> {
/// This allocates a new certificate chain from s2n.
///
/// Corresponds to [s2n_cert_chain_and_key_new].
pub(crate) fn allocate_owned() -> Result<CertificateChain<'static>, Error> {
crate::init::init();
unsafe {
let ptr = s2n_cert_chain_and_key_new().into_result()?;
Ok(CertificateChain {
ptr: Arc::new(CertificateChainHandle::from_owned(ptr)),
_lifetime: PhantomData,
})
pub(crate) fn from_allocated(
handle: CertificateChainHandle<'static>,
) -> CertificateChain<'static> {
CertificateChain {
ptr: Arc::new(handle),
}
}

Expand All @@ -161,10 +161,7 @@ impl CertificateChain<'_> {
) -> CertificateChain<'a> {
let handle = Arc::new(CertificateChainHandle::from_reference(ptr));

CertificateChain {
ptr: handle,
_lifetime: PhantomData,
}
CertificateChain { ptr: handle }
}

pub fn iter(&self) -> CertificateChainIter<'_> {
Expand Down Expand Up @@ -202,14 +199,6 @@ impl CertificateChain<'_> {
self.len() == 0
}

/// SAFETY: Only one instance of `CertificateChain` may exist when this method
/// is called. s2n_cert_chain_and_key is not thread-safe, so it is not safe
/// to mutate the certificate chain if references are held across multiple threads.
pub(crate) unsafe fn as_mut_ptr(&mut self) -> *mut s2n_cert_chain_and_key {
debug_assert_eq!(Arc::strong_count(&self.ptr), 1);
self.ptr.cert.as_ptr()
}

pub(crate) fn as_ptr(&self) -> *const s2n_cert_chain_and_key {
self.ptr.cert.as_ptr() as *const _
}
Expand Down
15 changes: 9 additions & 6 deletions bindings/rust/extended/s2n-tls/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use crate::renegotiate::RenegotiateState;
use crate::{
callbacks::*,
cert_chain::CertificateChain,
cert_chain::{CertificateChain, CertificateChainHandle},
config::Config,
enums::*,
error::{Error, Fallible, Pollable},
Expand Down Expand Up @@ -1219,11 +1219,14 @@ impl Connection {
/// Corresponds to [s2n_connection_get_peer_cert_chain].
pub fn peer_cert_chain(&self) -> Result<CertificateChain<'static>, Error> {
unsafe {
let mut chain = CertificateChain::allocate_owned()?;
s2n_connection_get_peer_cert_chain(self.connection.as_ptr(), chain.as_mut_ptr())
.into_result()
.map(|_| ())?;
Ok(chain)
let chain_handle = CertificateChainHandle::allocate()?;
s2n_connection_get_peer_cert_chain(
self.connection.as_ptr(),
chain_handle.cert.as_ptr(),
)
.into_result()
.map(|_| ())?;
Ok(CertificateChain::from_allocated(chain_handle))
}
}

Expand Down

0 comments on commit 88ed139

Please sign in to comment.