diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d3ab3cb5..a2140ace 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,7 +3,7 @@ name: CI on: [push, pull_request] env: - minrust: 1.41.0 + minrust: 1.46.0 jobs: test: diff --git a/src/common/access_control_allow_origin.rs b/src/common/access_control_allow_origin.rs index c048bf83..8d38bf17 100644 --- a/src/common/access_control_allow_origin.rs +++ b/src/common/access_control_allow_origin.rs @@ -1,3 +1,5 @@ +use std::convert::TryFrom; + use super::origin::Origin; use util::{IterExt, TryFromValues}; use HeaderValue; @@ -25,9 +27,11 @@ use HeaderValue; /// ``` /// # extern crate headers; /// use headers::AccessControlAllowOrigin; +/// use std::convert::TryFrom; /// /// let any_origin = AccessControlAllowOrigin::ANY; /// let null_origin = AccessControlAllowOrigin::NULL; +/// let origin = AccessControlAllowOrigin::try_from("http://web-platform.test:8000"); /// ``` #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct AccessControlAllowOrigin(OriginOrAny); @@ -60,6 +64,26 @@ impl AccessControlAllowOrigin { } } +impl TryFrom<&str> for AccessControlAllowOrigin { + type Error = ::Error; + + fn try_from(s: &str) -> Result { + let header_value = HeaderValue::from_str(s).map_err(|_| ::Error::invalid())?; + let origin = OriginOrAny::try_from(&header_value)?; + Ok(Self(origin)) + } +} + +impl TryFrom<&HeaderValue> for OriginOrAny { + type Error = ::Error; + + fn try_from(header_value: &HeaderValue) -> Result { + Origin::try_from_value(header_value) + .map(OriginOrAny::Origin) + .ok_or_else(::Error::invalid) + } +} + impl TryFromValues for OriginOrAny { fn try_from_values<'i, I>(values: &mut I) -> Result where @@ -89,12 +113,14 @@ impl<'a> From<&'a OriginOrAny> for HeaderValue { #[cfg(test)] mod tests { + use super::super::{test_decode, test_encode}; use super::*; #[test] fn origin() { let s = "http://web-platform.test:8000"; + let allow_origin = test_decode::(&[s]).unwrap(); { let origin = allow_origin.origin().unwrap(); @@ -107,6 +133,22 @@ mod tests { assert_eq!(headers["access-control-allow-origin"], s); } + #[test] + fn try_from_origin() { + let s = "http://web-platform.test:8000"; + + let allow_origin = AccessControlAllowOrigin::try_from(s).unwrap(); + { + let origin = allow_origin.origin().unwrap(); + assert_eq!(origin.scheme(), "http"); + assert_eq!(origin.hostname(), "web-platform.test"); + assert_eq!(origin.port(), Some(8000)); + } + + let headers = test_encode(allow_origin); + assert_eq!(headers["access-control-allow-origin"], s); + } + #[test] fn any() { let allow_origin = test_decode::(&["*"]).unwrap();