diff --git a/src/base/name/label.rs b/src/base/name/label.rs index 9ec8e7d16..d334664e7 100644 --- a/src/base/name/label.rs +++ b/src/base/name/label.rs @@ -51,7 +51,7 @@ impl Label { /// # Safety /// /// The `slice` must be at most 63 octets long. - pub(super) unsafe fn from_slice_unchecked(slice: &[u8]) -> &Self { + pub(super) const unsafe fn from_slice_unchecked(slice: &[u8]) -> &Self { // SAFETY: Label has repr(transparent) mem::transmute(slice) } @@ -110,7 +110,7 @@ impl Label { /// /// On success, the function returns a label and the remainder of /// the slice. - pub fn split_from( + pub const fn split_from( slice: &[u8], ) -> Result<(&Self, &[u8]), SplitLabelError> { let head = match slice.first() { @@ -125,11 +125,11 @@ impl Label { )) } 0xC0..=0xFF => { - let res = match slice.get(1) { - Some(ch) => u16::from(*ch), - None => return Err(SplitLabelError::ShortInput), - }; - let res = res | ((u16::from(head) & 0x3F) << 8); + if slice.len() < 2 { + return Err(SplitLabelError::ShortInput); + } + let res = slice[1] as u16; + let res = res | (((head as u16) & 0x3F) << 8); return Err(SplitLabelError::Pointer(res)); } _ => { @@ -141,10 +141,10 @@ impl Label { if slice.len() < end { return Err(SplitLabelError::ShortInput); } - Ok(( - unsafe { Self::from_slice_unchecked(&slice[1..end]) }, - &slice[end..], - )) + + let (left, right) = slice.split_at(end); + let (_, label_data) = left.split_at(1); + Ok((unsafe { Self::from_slice_unchecked(label_data) }, right)) } /// Splits a mutable label from the beginning of an octets slice. @@ -211,7 +211,7 @@ impl Label { /// Returns a reference to the underlying octets slice. #[must_use] - pub fn as_slice(&self) -> &[u8] { + pub const fn as_slice(&self) -> &[u8] { &self.0 } @@ -320,13 +320,13 @@ impl Label { /// Returns whether this is the empty label. #[must_use] - pub fn is_empty(&self) -> bool { + pub const fn is_empty(&self) -> bool { self.as_slice().is_empty() } /// Returns whether the label is the root label. #[must_use] - pub fn is_root(&self) -> bool { + pub const fn is_root(&self) -> bool { self.is_empty() } diff --git a/src/base/name/relative.rs b/src/base/name/relative.rs index 7c5dbe900..973046c27 100644 --- a/src/base/name/relative.rs +++ b/src/base/name/relative.rs @@ -133,7 +133,7 @@ impl RelativeName<[u8]> { /// The same rules as for [`from_octets_unchecked`] apply. /// /// [`from_octets_unchecked`]: RelativeName::from_octets_unchecked - pub(super) unsafe fn from_slice_unchecked(slice: &[u8]) -> &Self { + pub(super) const unsafe fn from_slice_unchecked(slice: &[u8]) -> &Self { // SAFETY: RelativeName has repr(transparent) mem::transmute(slice) } @@ -148,9 +148,13 @@ impl RelativeName<[u8]> { /// use domain::base::name::RelativeName; /// RelativeName::from_slice(b"\x0c_submissions\x04_tcp"); /// ``` - pub fn from_slice(slice: &[u8]) -> Result<&Self, RelativeNameError> { - Self::check_slice(slice)?; - Ok(unsafe { Self::from_slice_unchecked(slice) }) + pub const fn from_slice( + slice: &[u8], + ) -> Result<&Self, RelativeNameError> { + match Self::check_slice(slice) { + Ok(()) => Ok(unsafe { Self::from_slice_unchecked(slice) }), + Err(err) => Err(err), + } } /// Returns an empty relative name atop a unsized slice. @@ -165,16 +169,33 @@ impl RelativeName<[u8]> { } /// Checks whether an octet slice contains a correctly encoded name. - pub(super) fn check_slice( + pub(super) const fn check_slice( mut slice: &[u8], ) -> Result<(), RelativeNameError> { if slice.len() > 254 { - return Err(RelativeNameErrorEnum::LongName.into()); + return Err(RelativeNameError(RelativeNameErrorEnum::LongName)); } while !slice.is_empty() { - let (label, tail) = Label::split_from(slice)?; + let (label, tail) = match Label::split_from(slice) { + Ok((label, tail)) => (label, tail), + Err(err) => { + return Err(RelativeNameError(match err { + SplitLabelError::Pointer(_) => { + RelativeNameErrorEnum::CompressedName + } + SplitLabelError::BadType(t) => { + RelativeNameErrorEnum::BadLabel(t) + } + SplitLabelError::ShortInput => { + RelativeNameErrorEnum::ShortInput + } + })); + } + }; if label.is_root() { - return Err(RelativeNameErrorEnum::AbsoluteName.into()); + return Err(RelativeNameError( + RelativeNameErrorEnum::AbsoluteName, + )); } slice = tail; } @@ -1773,6 +1794,31 @@ mod test { cmp(b"\x07example\x03com", "example.com"); } + const VALID_NAME: &RelativeName<[u8]> = + match RelativeName::from_slice(b"\x03www\x07example") { + Ok(name) => name, + Err(_) => panic!("VALID_NAME failed at compile time"), + }; + const EMPTY_NAME: &RelativeName<[u8]> = + match RelativeName::from_slice(b"") { + Ok(name) => name, + Err(_) => { + panic!("EMPTY_NAME failed at compile time") + } + }; + const INVALID_NAME: RelativeNameError = + match RelativeName::from_slice(b"\x03www\x07example\x03com\0") { + Ok(_) => panic!("INVALID_NAME succeeded at compile time"), + Err(err) => err, + }; + + #[test] + fn const_from_slice() { + assert_eq!(VALID_NAME.as_slice(), b"\x03www\x07example"); + assert_eq!(EMPTY_NAME.as_slice(), b""); + assert_eq!(INVALID_NAME, RelativeNameErrorEnum::AbsoluteName.into()); + } + #[cfg(all(feature = "serde", feature = "std"))] #[test] fn ser_de() {