Skip to content

Commit e03e46f

Browse files
committed
Add missing where clause to guarentee correct derivation of Send
1 parent 60833da commit e03e46f

File tree

2 files changed

+368
-1
lines changed

2 files changed

+368
-1
lines changed

mbedtls/src/wrapper_macros.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ macro_rules! define_struct {
179179
);
180180

181181
as_item!(
182-
unsafe impl<$($g)*> Send for $name<$($g)*> {}
182+
unsafe impl<$($g)*> Send for $name<$($g)*>
183+
where $($g: Send)*
184+
{}
183185
);
184186
};
185187

mbedtls/test.rs

+365
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
pub mod context {
2+
use core::any::Any;
3+
use core::result::Result as StdResult;
4+
#[cfg(feature = "std")]
5+
use std::io::{Read, Write, Result as IoResult};
6+
#[cfg(feature = "std")]
7+
use std::sync::Arc;
8+
use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void};
9+
use mbedtls_sys::types::size_t;
10+
use mbedtls_sys::*;
11+
use crate::alloc::{List as MbedtlsList};
12+
use crate::error::{Error, Result, IntoResult};
13+
use crate::pk::Pk;
14+
use crate::private::UnsafeFrom;
15+
use crate::ssl::config::{Config, Version, AuthMode};
16+
use crate::x509::{Certificate, Crl, VerifyError};
17+
pub trait IoCallback: Any {
18+
unsafe extern "C" fn call_recv(
19+
user_data: *mut c_void,
20+
data: *mut c_uchar,
21+
len: size_t,
22+
) -> c_int
23+
where
24+
Self: Sized;
25+
unsafe extern "C" fn call_send(
26+
user_data: *mut c_void,
27+
data: *const c_uchar,
28+
len: size_t,
29+
) -> c_int
30+
where
31+
Self: Sized;
32+
fn data_ptr(&mut self) -> *mut c_void;
33+
}
34+
impl<IO: Read + Write + 'static> IoCallback for IO {
35+
unsafe extern "C" fn call_recv(
36+
user_data: *mut c_void,
37+
data: *mut c_uchar,
38+
len: size_t,
39+
) -> c_int {
40+
let len = if len > (c_int::max_value() as size_t) {
41+
c_int::max_value() as size_t
42+
} else {
43+
len
44+
};
45+
match (&mut *(user_data as *mut IO)).read(::core::slice::from_raw_parts_mut(data, len))
46+
{
47+
Ok(i) => i as c_int,
48+
Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED,
49+
}
50+
}
51+
unsafe extern "C" fn call_send(
52+
user_data: *mut c_void,
53+
data: *const c_uchar,
54+
len: size_t,
55+
) -> c_int {
56+
let len = if len > (c_int::max_value() as size_t) {
57+
c_int::max_value() as size_t
58+
} else {
59+
len
60+
};
61+
match (&mut *(user_data as *mut IO)).write(::core::slice::from_raw_parts(data, len)) {
62+
Ok(i) => i as c_int,
63+
Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED,
64+
}
65+
}
66+
fn data_ptr(&mut self) -> *mut c_void {
67+
self as *mut IO as *mut _
68+
}
69+
}
70+
#[allow(dead_code)]
71+
#[repr(C)]
72+
pub struct Context<S> {
73+
inner: ::mbedtls_sys::ssl_context,
74+
config: Arc<Config>,
75+
io: Option<Box<S>>,
76+
handshake_ca_cert: Option<Arc<MbedtlsList<Certificate>>>,
77+
handshake_crl: Option<Arc<Crl>>,
78+
handshake_cert: Vec<Arc<MbedtlsList<Certificate>>>,
79+
handshake_pk: Vec<Arc<Pk>>,
80+
}
81+
#[allow(dead_code)]
82+
impl<S> Context<S> {
83+
pub(crate) fn into_inner(self) -> ::mbedtls_sys::ssl_context {
84+
let inner = self.inner;
85+
::core::mem::forget(self);
86+
inner
87+
}
88+
pub(crate) fn handle(&self) -> &::mbedtls_sys::ssl_context {
89+
&self.inner
90+
}
91+
pub(crate) fn handle_mut(&mut self) -> &mut ::mbedtls_sys::ssl_context {
92+
&mut self.inner
93+
}
94+
}
95+
unsafe impl<S> Send for Context<S> where S: Send {}
96+
impl<'a, S> Into<*const ssl_context> for &'a Context<S> {
97+
fn into(self) -> *const ssl_context {
98+
self.handle()
99+
}
100+
}
101+
impl<'a, S> Into<*mut ssl_context> for &'a mut Context<S> {
102+
fn into(self) -> *mut ssl_context {
103+
self.handle_mut()
104+
}
105+
}
106+
impl<S> Context<S> {
107+
#[doc = r" Needed for compatibility with mbedtls - where we could pass"]
108+
#[doc = r" `*const` but function signature requires `*mut`"]
109+
#[allow(dead_code)]
110+
pub(crate) unsafe fn inner_ffi_mut(&self) -> *mut ssl_context {
111+
self.handle() as *const _ as *mut ssl_context
112+
}
113+
}
114+
impl<'a, S> crate::private::UnsafeFrom<*const ssl_context> for &'a Context<S> {
115+
unsafe fn from(ptr: *const ssl_context) -> Option<Self> {
116+
(ptr as *const Context<S>).as_ref()
117+
}
118+
}
119+
impl<'a, S> crate::private::UnsafeFrom<*mut ssl_context> for &'a mut Context<S> {
120+
unsafe fn from(ptr: *mut ssl_context) -> Option<Self> {
121+
(ptr as *mut Context<S>).as_mut()
122+
}
123+
}
124+
impl<S: IoCallback> Context<S> {
125+
pub fn establish(&mut self, io: S, hostname: Option<&str>) -> Result<()> {
126+
unsafe {
127+
let mut io = Box::new(io);
128+
ssl_session_reset(self.into()).into_result()?;
129+
self.set_hostname(hostname)?;
130+
let ptr = &mut *io as *mut _ as *mut c_void;
131+
ssl_set_bio(
132+
self.into(),
133+
ptr,
134+
Some(S::call_send),
135+
Some(S::call_recv),
136+
None,
137+
);
138+
self.io = Some(io);
139+
self.handshake_cert.clear();
140+
self.handshake_pk.clear();
141+
self.handshake_ca_cert = None;
142+
self.handshake_crl = None;
143+
match ssl_handshake(self.into()).into_result() {
144+
Err(e) => {
145+
ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None);
146+
self.io = None;
147+
Err(e)
148+
}
149+
Ok(_) => Ok(()),
150+
}
151+
}
152+
}
153+
}
154+
impl<S> Context<S> {
155+
pub fn new(config: Arc<Config>) -> Self {
156+
let mut inner = ssl_context::default();
157+
unsafe {
158+
ssl_init(&mut inner);
159+
ssl_setup(&mut inner, (&*config).into());
160+
};
161+
Context {
162+
inner,
163+
config: config.clone(),
164+
io: None,
165+
handshake_ca_cert: None,
166+
handshake_crl: None,
167+
handshake_cert: ::alloc::vec::Vec::new(),
168+
handshake_pk: ::alloc::vec::Vec::new(),
169+
}
170+
}
171+
#[cfg(feature = "std")]
172+
fn set_hostname(&mut self, hostname: Option<&str>) -> Result<()> {
173+
if let Some(s) = hostname {
174+
let cstr = ::std::ffi::CString::new(s).map_err(|_| Error::SslBadInputData)?;
175+
unsafe {
176+
ssl_set_hostname(self.into(), cstr.as_ptr())
177+
.into_result()
178+
.map(|_| ())
179+
}
180+
} else {
181+
Ok(())
182+
}
183+
}
184+
pub fn verify_result(&self) -> StdResult<(), VerifyError> {
185+
match unsafe { ssl_get_verify_result(self.into()) } {
186+
0 => Ok(()),
187+
flags => Err(VerifyError::from_bits_truncate(flags)),
188+
}
189+
}
190+
pub fn config(&self) -> &Arc<Config> {
191+
&self.config
192+
}
193+
pub fn close(&mut self) {
194+
unsafe {
195+
ssl_close_notify(self.into());
196+
ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None);
197+
self.io = None;
198+
}
199+
}
200+
pub fn io(&self) -> Option<&Box<S>> {
201+
self.io.as_ref()
202+
}
203+
pub fn io_mut(&mut self) -> Option<&mut Box<S>> {
204+
self.io.as_mut()
205+
}
206+
#[doc = " Return the minor number of the negotiated TLS version"]
207+
pub fn minor_version(&self) -> i32 {
208+
self.inner.minor_ver
209+
}
210+
#[doc = " Return the major number of the negotiated TLS version"]
211+
pub fn major_version(&self) -> i32 {
212+
self.inner.major_ver
213+
}
214+
#[doc = " Return the number of bytes currently available to read that"]
215+
#[doc = " are stored in the Session's internal read buffer"]
216+
pub fn bytes_available(&self) -> usize {
217+
unsafe { ssl_get_bytes_avail(self.into()) }
218+
}
219+
pub fn version(&self) -> Version {
220+
let major = self.major_version();
221+
{
222+
match (&major, &3) {
223+
(left_val, right_val) => {
224+
if !(*left_val == *right_val) {
225+
let kind = ::core::panicking::AssertKind::Eq;
226+
::core::panicking::assert_failed(
227+
kind,
228+
&*left_val,
229+
&*right_val,
230+
::core::option::Option::None,
231+
);
232+
}
233+
}
234+
}
235+
};
236+
let minor = self.minor_version();
237+
match minor {
238+
0 => Version::Ssl3,
239+
1 => Version::Tls1_0,
240+
2 => Version::Tls1_1,
241+
3 => Version::Tls1_2,
242+
_ => ::core::panicking::panic_fmt(::core::fmt::Arguments::new_v1(
243+
&["internal error: entered unreachable code: "],
244+
&match (&"unexpected TLS version",) {
245+
(arg0,) => [::core::fmt::ArgumentV1::new(
246+
arg0,
247+
::core::fmt::Display::fmt,
248+
)],
249+
},
250+
)),
251+
}
252+
}
253+
#[doc = " Return the 16-bit ciphersuite identifier."]
254+
#[doc = " All assigned ciphersuites are listed by the IANA in"]
255+
#[doc = " https://www.iana.org/assignments/tls-parameters/tls-parameters.txt"]
256+
pub fn ciphersuite(&self) -> Result<u16> {
257+
if self.inner.session.is_null() {
258+
return Err(Error::SslBadInputData);
259+
}
260+
Ok(unsafe { self.inner.session.as_ref().unwrap().ciphersuite as u16 })
261+
}
262+
pub fn peer_cert(&self) -> Result<Option<&MbedtlsList<Certificate>>> {
263+
if self.inner.session.is_null() {
264+
return Err(Error::SslBadInputData);
265+
}
266+
unsafe {
267+
let peer_cert: &MbedtlsList<Certificate> = UnsafeFrom::from(
268+
&((*self.inner.session).peer_cert) as *const *mut x509_crt
269+
as *const *const x509_crt,
270+
)
271+
.ok_or(Error::SslBadInputData)?;
272+
Ok(Some(peer_cert))
273+
}
274+
}
275+
}
276+
impl<S> Drop for Context<S> {
277+
fn drop(&mut self) {
278+
unsafe {
279+
self.close();
280+
ssl_free(self.into());
281+
}
282+
}
283+
}
284+
impl<S> Read for Context<S> {
285+
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
286+
match unsafe { ssl_read(self.into(), buf.as_mut_ptr(), buf.len()).into_result() } {
287+
Err(Error::SslPeerCloseNotify) => Ok(0),
288+
Err(e) => Err(crate::private::error_to_io_error(e)),
289+
Ok(i) => Ok(i as usize),
290+
}
291+
}
292+
}
293+
impl<S> Write for Context<S> {
294+
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
295+
match unsafe { ssl_write(self.into(), buf.as_ptr(), buf.len()).into_result() } {
296+
Err(Error::SslPeerCloseNotify) => Ok(0),
297+
Err(e) => Err(crate::private::error_to_io_error(e)),
298+
Ok(i) => Ok(i as usize),
299+
}
300+
}
301+
fn flush(&mut self) -> IoResult<()> {
302+
Ok(())
303+
}
304+
}
305+
pub struct HandshakeContext<'ctx> {
306+
pub context: &'ctx mut Context<Box<dyn Any>>,
307+
}
308+
impl<'ctx> HandshakeContext<'ctx> {
309+
pub(crate) fn init(context: &'ctx mut Context<Box<dyn Any>>) -> Self {
310+
HandshakeContext { context }
311+
}
312+
pub fn set_authmode(&mut self, am: AuthMode) -> Result<()> {
313+
if self.context.inner.handshake as *const _ == ::core::ptr::null() {
314+
return Err(Error::SslBadInputData);
315+
}
316+
unsafe { ssl_set_hs_authmode(self.context.into(), am as i32) }
317+
Ok(())
318+
}
319+
pub fn set_ca_list(
320+
&mut self,
321+
chain: Arc<MbedtlsList<Certificate>>,
322+
crl: Option<Arc<Crl>>,
323+
) -> Result<()> {
324+
if self.context.inner.handshake as *const _ == ::core::ptr::null() {
325+
return Err(Error::SslBadInputData);
326+
}
327+
unsafe {
328+
ssl_set_hs_ca_chain(
329+
self.context.into(),
330+
chain.inner_ffi_mut(),
331+
crl.as_ref()
332+
.map(|crl| crl.inner_ffi_mut())
333+
.unwrap_or(::core::ptr::null_mut()),
334+
);
335+
}
336+
self.context.handshake_ca_cert = Some(chain);
337+
self.context.handshake_crl = crl;
338+
Ok(())
339+
}
340+
#[doc = " If this is never called, will use the set of private keys and"]
341+
#[doc = " certificates configured in the `Config` associated with this `Context`."]
342+
#[doc = " If this is called at least once, all those are ignored and the set"]
343+
#[doc = " specified using this function is used."]
344+
pub fn push_cert(
345+
&mut self,
346+
chain: Arc<MbedtlsList<Certificate>>,
347+
key: Arc<Pk>,
348+
) -> Result<()> {
349+
if self.context.inner.handshake as *const _ == ::core::ptr::null() {
350+
return Err(Error::SslBadInputData);
351+
}
352+
unsafe {
353+
ssl_set_hs_own_cert(
354+
self.context.into(),
355+
chain.inner_ffi_mut(),
356+
key.inner_ffi_mut(),
357+
)
358+
.into_result()?;
359+
}
360+
self.context.handshake_cert.push(chain);
361+
self.context.handshake_pk.push(key);
362+
Ok(())
363+
}
364+
}
365+
}

0 commit comments

Comments
 (0)