|
| 1 | +pub mod context { |
| 2 | + use core::any::Any; |
| 3 | + use core::result::Result as StdResult; |
| 4 | + #[cfg(feature = "std")] |
| 5 | + use std::io::{Read, Write, Result as IoResult}; |
| 6 | + #[cfg(feature = "std")] |
| 7 | + use std::sync::Arc; |
| 8 | + use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void}; |
| 9 | + use mbedtls_sys::types::size_t; |
| 10 | + use mbedtls_sys::*; |
| 11 | + use crate::alloc::{List as MbedtlsList}; |
| 12 | + use crate::error::{Error, Result, IntoResult}; |
| 13 | + use crate::pk::Pk; |
| 14 | + use crate::private::UnsafeFrom; |
| 15 | + use crate::ssl::config::{Config, Version, AuthMode}; |
| 16 | + use crate::x509::{Certificate, Crl, VerifyError}; |
| 17 | + pub trait IoCallback: Any { |
| 18 | + unsafe extern "C" fn call_recv( |
| 19 | + user_data: *mut c_void, |
| 20 | + data: *mut c_uchar, |
| 21 | + len: size_t, |
| 22 | + ) -> c_int |
| 23 | + where |
| 24 | + Self: Sized; |
| 25 | + unsafe extern "C" fn call_send( |
| 26 | + user_data: *mut c_void, |
| 27 | + data: *const c_uchar, |
| 28 | + len: size_t, |
| 29 | + ) -> c_int |
| 30 | + where |
| 31 | + Self: Sized; |
| 32 | + fn data_ptr(&mut self) -> *mut c_void; |
| 33 | + } |
| 34 | + impl<IO: Read + Write + 'static> IoCallback for IO { |
| 35 | + unsafe extern "C" fn call_recv( |
| 36 | + user_data: *mut c_void, |
| 37 | + data: *mut c_uchar, |
| 38 | + len: size_t, |
| 39 | + ) -> c_int { |
| 40 | + let len = if len > (c_int::max_value() as size_t) { |
| 41 | + c_int::max_value() as size_t |
| 42 | + } else { |
| 43 | + len |
| 44 | + }; |
| 45 | + match (&mut *(user_data as *mut IO)).read(::core::slice::from_raw_parts_mut(data, len)) |
| 46 | + { |
| 47 | + Ok(i) => i as c_int, |
| 48 | + Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED, |
| 49 | + } |
| 50 | + } |
| 51 | + unsafe extern "C" fn call_send( |
| 52 | + user_data: *mut c_void, |
| 53 | + data: *const c_uchar, |
| 54 | + len: size_t, |
| 55 | + ) -> c_int { |
| 56 | + let len = if len > (c_int::max_value() as size_t) { |
| 57 | + c_int::max_value() as size_t |
| 58 | + } else { |
| 59 | + len |
| 60 | + }; |
| 61 | + match (&mut *(user_data as *mut IO)).write(::core::slice::from_raw_parts(data, len)) { |
| 62 | + Ok(i) => i as c_int, |
| 63 | + Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED, |
| 64 | + } |
| 65 | + } |
| 66 | + fn data_ptr(&mut self) -> *mut c_void { |
| 67 | + self as *mut IO as *mut _ |
| 68 | + } |
| 69 | + } |
| 70 | + #[allow(dead_code)] |
| 71 | + #[repr(C)] |
| 72 | + pub struct Context<S> { |
| 73 | + inner: ::mbedtls_sys::ssl_context, |
| 74 | + config: Arc<Config>, |
| 75 | + io: Option<Box<S>>, |
| 76 | + handshake_ca_cert: Option<Arc<MbedtlsList<Certificate>>>, |
| 77 | + handshake_crl: Option<Arc<Crl>>, |
| 78 | + handshake_cert: Vec<Arc<MbedtlsList<Certificate>>>, |
| 79 | + handshake_pk: Vec<Arc<Pk>>, |
| 80 | + } |
| 81 | + #[allow(dead_code)] |
| 82 | + impl<S> Context<S> { |
| 83 | + pub(crate) fn into_inner(self) -> ::mbedtls_sys::ssl_context { |
| 84 | + let inner = self.inner; |
| 85 | + ::core::mem::forget(self); |
| 86 | + inner |
| 87 | + } |
| 88 | + pub(crate) fn handle(&self) -> &::mbedtls_sys::ssl_context { |
| 89 | + &self.inner |
| 90 | + } |
| 91 | + pub(crate) fn handle_mut(&mut self) -> &mut ::mbedtls_sys::ssl_context { |
| 92 | + &mut self.inner |
| 93 | + } |
| 94 | + } |
| 95 | + unsafe impl<S> Send for Context<S> where S: Send {} |
| 96 | + impl<'a, S> Into<*const ssl_context> for &'a Context<S> { |
| 97 | + fn into(self) -> *const ssl_context { |
| 98 | + self.handle() |
| 99 | + } |
| 100 | + } |
| 101 | + impl<'a, S> Into<*mut ssl_context> for &'a mut Context<S> { |
| 102 | + fn into(self) -> *mut ssl_context { |
| 103 | + self.handle_mut() |
| 104 | + } |
| 105 | + } |
| 106 | + impl<S> Context<S> { |
| 107 | + #[doc = r" Needed for compatibility with mbedtls - where we could pass"] |
| 108 | + #[doc = r" `*const` but function signature requires `*mut`"] |
| 109 | + #[allow(dead_code)] |
| 110 | + pub(crate) unsafe fn inner_ffi_mut(&self) -> *mut ssl_context { |
| 111 | + self.handle() as *const _ as *mut ssl_context |
| 112 | + } |
| 113 | + } |
| 114 | + impl<'a, S> crate::private::UnsafeFrom<*const ssl_context> for &'a Context<S> { |
| 115 | + unsafe fn from(ptr: *const ssl_context) -> Option<Self> { |
| 116 | + (ptr as *const Context<S>).as_ref() |
| 117 | + } |
| 118 | + } |
| 119 | + impl<'a, S> crate::private::UnsafeFrom<*mut ssl_context> for &'a mut Context<S> { |
| 120 | + unsafe fn from(ptr: *mut ssl_context) -> Option<Self> { |
| 121 | + (ptr as *mut Context<S>).as_mut() |
| 122 | + } |
| 123 | + } |
| 124 | + impl<S: IoCallback> Context<S> { |
| 125 | + pub fn establish(&mut self, io: S, hostname: Option<&str>) -> Result<()> { |
| 126 | + unsafe { |
| 127 | + let mut io = Box::new(io); |
| 128 | + ssl_session_reset(self.into()).into_result()?; |
| 129 | + self.set_hostname(hostname)?; |
| 130 | + let ptr = &mut *io as *mut _ as *mut c_void; |
| 131 | + ssl_set_bio( |
| 132 | + self.into(), |
| 133 | + ptr, |
| 134 | + Some(S::call_send), |
| 135 | + Some(S::call_recv), |
| 136 | + None, |
| 137 | + ); |
| 138 | + self.io = Some(io); |
| 139 | + self.handshake_cert.clear(); |
| 140 | + self.handshake_pk.clear(); |
| 141 | + self.handshake_ca_cert = None; |
| 142 | + self.handshake_crl = None; |
| 143 | + match ssl_handshake(self.into()).into_result() { |
| 144 | + Err(e) => { |
| 145 | + ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None); |
| 146 | + self.io = None; |
| 147 | + Err(e) |
| 148 | + } |
| 149 | + Ok(_) => Ok(()), |
| 150 | + } |
| 151 | + } |
| 152 | + } |
| 153 | + } |
| 154 | + impl<S> Context<S> { |
| 155 | + pub fn new(config: Arc<Config>) -> Self { |
| 156 | + let mut inner = ssl_context::default(); |
| 157 | + unsafe { |
| 158 | + ssl_init(&mut inner); |
| 159 | + ssl_setup(&mut inner, (&*config).into()); |
| 160 | + }; |
| 161 | + Context { |
| 162 | + inner, |
| 163 | + config: config.clone(), |
| 164 | + io: None, |
| 165 | + handshake_ca_cert: None, |
| 166 | + handshake_crl: None, |
| 167 | + handshake_cert: ::alloc::vec::Vec::new(), |
| 168 | + handshake_pk: ::alloc::vec::Vec::new(), |
| 169 | + } |
| 170 | + } |
| 171 | + #[cfg(feature = "std")] |
| 172 | + fn set_hostname(&mut self, hostname: Option<&str>) -> Result<()> { |
| 173 | + if let Some(s) = hostname { |
| 174 | + let cstr = ::std::ffi::CString::new(s).map_err(|_| Error::SslBadInputData)?; |
| 175 | + unsafe { |
| 176 | + ssl_set_hostname(self.into(), cstr.as_ptr()) |
| 177 | + .into_result() |
| 178 | + .map(|_| ()) |
| 179 | + } |
| 180 | + } else { |
| 181 | + Ok(()) |
| 182 | + } |
| 183 | + } |
| 184 | + pub fn verify_result(&self) -> StdResult<(), VerifyError> { |
| 185 | + match unsafe { ssl_get_verify_result(self.into()) } { |
| 186 | + 0 => Ok(()), |
| 187 | + flags => Err(VerifyError::from_bits_truncate(flags)), |
| 188 | + } |
| 189 | + } |
| 190 | + pub fn config(&self) -> &Arc<Config> { |
| 191 | + &self.config |
| 192 | + } |
| 193 | + pub fn close(&mut self) { |
| 194 | + unsafe { |
| 195 | + ssl_close_notify(self.into()); |
| 196 | + ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None); |
| 197 | + self.io = None; |
| 198 | + } |
| 199 | + } |
| 200 | + pub fn io(&self) -> Option<&Box<S>> { |
| 201 | + self.io.as_ref() |
| 202 | + } |
| 203 | + pub fn io_mut(&mut self) -> Option<&mut Box<S>> { |
| 204 | + self.io.as_mut() |
| 205 | + } |
| 206 | + #[doc = " Return the minor number of the negotiated TLS version"] |
| 207 | + pub fn minor_version(&self) -> i32 { |
| 208 | + self.inner.minor_ver |
| 209 | + } |
| 210 | + #[doc = " Return the major number of the negotiated TLS version"] |
| 211 | + pub fn major_version(&self) -> i32 { |
| 212 | + self.inner.major_ver |
| 213 | + } |
| 214 | + #[doc = " Return the number of bytes currently available to read that"] |
| 215 | + #[doc = " are stored in the Session's internal read buffer"] |
| 216 | + pub fn bytes_available(&self) -> usize { |
| 217 | + unsafe { ssl_get_bytes_avail(self.into()) } |
| 218 | + } |
| 219 | + pub fn version(&self) -> Version { |
| 220 | + let major = self.major_version(); |
| 221 | + { |
| 222 | + match (&major, &3) { |
| 223 | + (left_val, right_val) => { |
| 224 | + if !(*left_val == *right_val) { |
| 225 | + let kind = ::core::panicking::AssertKind::Eq; |
| 226 | + ::core::panicking::assert_failed( |
| 227 | + kind, |
| 228 | + &*left_val, |
| 229 | + &*right_val, |
| 230 | + ::core::option::Option::None, |
| 231 | + ); |
| 232 | + } |
| 233 | + } |
| 234 | + } |
| 235 | + }; |
| 236 | + let minor = self.minor_version(); |
| 237 | + match minor { |
| 238 | + 0 => Version::Ssl3, |
| 239 | + 1 => Version::Tls1_0, |
| 240 | + 2 => Version::Tls1_1, |
| 241 | + 3 => Version::Tls1_2, |
| 242 | + _ => ::core::panicking::panic_fmt(::core::fmt::Arguments::new_v1( |
| 243 | + &["internal error: entered unreachable code: "], |
| 244 | + &match (&"unexpected TLS version",) { |
| 245 | + (arg0,) => [::core::fmt::ArgumentV1::new( |
| 246 | + arg0, |
| 247 | + ::core::fmt::Display::fmt, |
| 248 | + )], |
| 249 | + }, |
| 250 | + )), |
| 251 | + } |
| 252 | + } |
| 253 | + #[doc = " Return the 16-bit ciphersuite identifier."] |
| 254 | + #[doc = " All assigned ciphersuites are listed by the IANA in"] |
| 255 | + #[doc = " https://www.iana.org/assignments/tls-parameters/tls-parameters.txt"] |
| 256 | + pub fn ciphersuite(&self) -> Result<u16> { |
| 257 | + if self.inner.session.is_null() { |
| 258 | + return Err(Error::SslBadInputData); |
| 259 | + } |
| 260 | + Ok(unsafe { self.inner.session.as_ref().unwrap().ciphersuite as u16 }) |
| 261 | + } |
| 262 | + pub fn peer_cert(&self) -> Result<Option<&MbedtlsList<Certificate>>> { |
| 263 | + if self.inner.session.is_null() { |
| 264 | + return Err(Error::SslBadInputData); |
| 265 | + } |
| 266 | + unsafe { |
| 267 | + let peer_cert: &MbedtlsList<Certificate> = UnsafeFrom::from( |
| 268 | + &((*self.inner.session).peer_cert) as *const *mut x509_crt |
| 269 | + as *const *const x509_crt, |
| 270 | + ) |
| 271 | + .ok_or(Error::SslBadInputData)?; |
| 272 | + Ok(Some(peer_cert)) |
| 273 | + } |
| 274 | + } |
| 275 | + } |
| 276 | + impl<S> Drop for Context<S> { |
| 277 | + fn drop(&mut self) { |
| 278 | + unsafe { |
| 279 | + self.close(); |
| 280 | + ssl_free(self.into()); |
| 281 | + } |
| 282 | + } |
| 283 | + } |
| 284 | + impl<S> Read for Context<S> { |
| 285 | + fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> { |
| 286 | + match unsafe { ssl_read(self.into(), buf.as_mut_ptr(), buf.len()).into_result() } { |
| 287 | + Err(Error::SslPeerCloseNotify) => Ok(0), |
| 288 | + Err(e) => Err(crate::private::error_to_io_error(e)), |
| 289 | + Ok(i) => Ok(i as usize), |
| 290 | + } |
| 291 | + } |
| 292 | + } |
| 293 | + impl<S> Write for Context<S> { |
| 294 | + fn write(&mut self, buf: &[u8]) -> IoResult<usize> { |
| 295 | + match unsafe { ssl_write(self.into(), buf.as_ptr(), buf.len()).into_result() } { |
| 296 | + Err(Error::SslPeerCloseNotify) => Ok(0), |
| 297 | + Err(e) => Err(crate::private::error_to_io_error(e)), |
| 298 | + Ok(i) => Ok(i as usize), |
| 299 | + } |
| 300 | + } |
| 301 | + fn flush(&mut self) -> IoResult<()> { |
| 302 | + Ok(()) |
| 303 | + } |
| 304 | + } |
| 305 | + pub struct HandshakeContext<'ctx> { |
| 306 | + pub context: &'ctx mut Context<Box<dyn Any>>, |
| 307 | + } |
| 308 | + impl<'ctx> HandshakeContext<'ctx> { |
| 309 | + pub(crate) fn init(context: &'ctx mut Context<Box<dyn Any>>) -> Self { |
| 310 | + HandshakeContext { context } |
| 311 | + } |
| 312 | + pub fn set_authmode(&mut self, am: AuthMode) -> Result<()> { |
| 313 | + if self.context.inner.handshake as *const _ == ::core::ptr::null() { |
| 314 | + return Err(Error::SslBadInputData); |
| 315 | + } |
| 316 | + unsafe { ssl_set_hs_authmode(self.context.into(), am as i32) } |
| 317 | + Ok(()) |
| 318 | + } |
| 319 | + pub fn set_ca_list( |
| 320 | + &mut self, |
| 321 | + chain: Arc<MbedtlsList<Certificate>>, |
| 322 | + crl: Option<Arc<Crl>>, |
| 323 | + ) -> Result<()> { |
| 324 | + if self.context.inner.handshake as *const _ == ::core::ptr::null() { |
| 325 | + return Err(Error::SslBadInputData); |
| 326 | + } |
| 327 | + unsafe { |
| 328 | + ssl_set_hs_ca_chain( |
| 329 | + self.context.into(), |
| 330 | + chain.inner_ffi_mut(), |
| 331 | + crl.as_ref() |
| 332 | + .map(|crl| crl.inner_ffi_mut()) |
| 333 | + .unwrap_or(::core::ptr::null_mut()), |
| 334 | + ); |
| 335 | + } |
| 336 | + self.context.handshake_ca_cert = Some(chain); |
| 337 | + self.context.handshake_crl = crl; |
| 338 | + Ok(()) |
| 339 | + } |
| 340 | + #[doc = " If this is never called, will use the set of private keys and"] |
| 341 | + #[doc = " certificates configured in the `Config` associated with this `Context`."] |
| 342 | + #[doc = " If this is called at least once, all those are ignored and the set"] |
| 343 | + #[doc = " specified using this function is used."] |
| 344 | + pub fn push_cert( |
| 345 | + &mut self, |
| 346 | + chain: Arc<MbedtlsList<Certificate>>, |
| 347 | + key: Arc<Pk>, |
| 348 | + ) -> Result<()> { |
| 349 | + if self.context.inner.handshake as *const _ == ::core::ptr::null() { |
| 350 | + return Err(Error::SslBadInputData); |
| 351 | + } |
| 352 | + unsafe { |
| 353 | + ssl_set_hs_own_cert( |
| 354 | + self.context.into(), |
| 355 | + chain.inner_ffi_mut(), |
| 356 | + key.inner_ffi_mut(), |
| 357 | + ) |
| 358 | + .into_result()?; |
| 359 | + } |
| 360 | + self.context.handshake_cert.push(chain); |
| 361 | + self.context.handshake_pk.push(key); |
| 362 | + Ok(()) |
| 363 | + } |
| 364 | + } |
| 365 | +} |
0 commit comments