From b0bf49137bde7066747d71a3d3fb6dbc46556ff8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 6 Dec 2024 16:46:35 +0100 Subject: [PATCH 001/167] Add module 'new_base' --- src/lib.rs | 1 + src/new_base/mod.rs | 5 +++++ 2 files changed, 6 insertions(+) create mode 100644 src/new_base/mod.rs diff --git a/src/lib.rs b/src/lib.rs index 0d0a4a2ba..e9aef12b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -193,6 +193,7 @@ extern crate core; pub mod base; pub mod dep; pub mod net; +pub mod new_base; pub mod rdata; pub mod resolv; pub mod sign; diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs new file mode 100644 index 000000000..4257c2712 --- /dev/null +++ b/src/new_base/mod.rs @@ -0,0 +1,5 @@ +//! Basic DNS. +//! +//! This module provides the essential types and functionality for working +//! with DNS. Most importantly, it provides functionality for parsing and +//! building DNS messages on the wire. From ea600fd3f5e666c2812a29c15011f77190b76042 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 9 Dec 2024 17:36:41 +0100 Subject: [PATCH 002/167] [new_base] Add module 'name' --- src/new_base/mod.rs | 2 ++ src/new_base/name/mod.rs | 15 +++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 src/new_base/name/mod.rs diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 4257c2712..c29a0b49f 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -3,3 +3,5 @@ //! This module provides the essential types and functionality for working //! with DNS. Most importantly, it provides functionality for parsing and //! building DNS messages on the wire. + +pub mod name; diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs new file mode 100644 index 000000000..e288f9ee4 --- /dev/null +++ b/src/new_base/name/mod.rs @@ -0,0 +1,15 @@ +//! Domain names. +//! +//! Domain names are a core concept of DNS. The whole system is essentially +//! just a mapping from domain names to arbitrary information. This module +//! provides types and essential functionality for working with them. +//! +//! A domain name is a sequence of labels, separated by ASCII periods (`.`). +//! For example, `example.org.` contains three labels: `example`, `org`, and +//! `` (the root label). Outside DNS-specific code, the root label (and its +//! separator) are almost always omitted, but keep them in mind here. +//! +//! Domain names form a hierarchy, where `b.a` is the "parent" of `.c.b.a`. +//! The owner of `example.org` is thus responsible for _every_ domain ending +//! with the `.example.org` suffix. The reverse order in which this hierarchy +//! is expressed can sometimes be confusing. From 6fb0957e79691dc30608d6204640e3b2c8bc8f53 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 9 Dec 2024 20:01:52 +0100 Subject: [PATCH 003/167] [new_base/name] Define labels --- src/new_base/name/label.rs | 181 +++++++++++++++++++++++++++++++++++++ src/new_base/name/mod.rs | 3 + 2 files changed, 184 insertions(+) create mode 100644 src/new_base/name/label.rs diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs new file mode 100644 index 000000000..d7d83b6a1 --- /dev/null +++ b/src/new_base/name/label.rs @@ -0,0 +1,181 @@ +//! Labels in domain names. + +//----------- Label ---------------------------------------------------------- + +use core::{ + cmp::Ordering, + fmt, + hash::{Hash, Hasher}, +}; + +/// A label in a domain name. +/// +/// A label contains up to 63 bytes of arbitrary data. +#[repr(transparent)] +pub struct Label([u8]); + +//--- Associated Constants + +impl Label { + /// The root label. + pub const ROOT: &'static Self = { + // SAFETY: All slices of 63 bytes or less are valid. + unsafe { Self::from_bytes_unchecked(b"") } + }; + + /// The wildcard label. + pub const WILDCARD: &'static Self = { + // SAFETY: All slices of 63 bytes or less are valid. + unsafe { Self::from_bytes_unchecked(b"*") } + }; +} + +//--- Construction + +impl Label { + /// Assume a byte slice is a valid label. + /// + /// # Safety + /// + /// The byte slice must have length 63 or less. + pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'Label' is 'repr(transparent)' to '[u8]'. + unsafe { core::mem::transmute(bytes) } + } +} + +//--- Inspection + +impl Label { + /// The length of this label, in bytes. + #[allow(clippy::len_without_is_empty)] + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Whether this is the root label. + pub const fn is_root(&self) -> bool { + self.0.is_empty() + } + + /// Whether this is a wildcard label. + pub const fn is_wildcard(&self) -> bool { + // NOTE: '==' for byte slices is not 'const'. + self.0.len() == 1 && self.0[0] == b'*' + } + + /// The bytes making up this label. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } +} + +//--- Access to the underlying bytes + +impl AsRef<[u8]> for Label { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl<'a> From<&'a Label> for &'a [u8] { + fn from(value: &'a Label) -> Self { + &value.0 + } +} + +//--- Comparison + +impl PartialEq for Label { + /// Compare two labels for equality. + /// + /// Labels are compared ASCII-case-insensitively. + fn eq(&self, other: &Self) -> bool { + let this = self.as_bytes().iter().map(u8::to_ascii_lowercase); + let that = other.as_bytes().iter().map(u8::to_ascii_lowercase); + this.eq(that) + } +} + +impl Eq for Label {} + +//--- Ordering + +impl PartialOrd for Label { + /// Determine the order between labels. + /// + /// Any uppercase ASCII characters in the labels are treated as if they + /// were lowercase. The first unequal byte between two labels determines + /// its ordering: the label with the smaller byte value is the lesser. If + /// two labels have all the same bytes, the shorter label is lesser; if + /// they are the same length, they are equal. + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Label { + /// Determine the order between labels. + /// + /// Any uppercase ASCII characters in the labels are treated as if they + /// were lowercase. The first unequal byte between two labels determines + /// its ordering: the label with the smaller byte value is the lesser. If + /// two labels have all the same bytes, the shorter label is lesser; if + /// they are the same length, they are equal. + fn cmp(&self, other: &Self) -> Ordering { + let this = self.as_bytes().iter().map(u8::to_ascii_lowercase); + let that = other.as_bytes().iter().map(u8::to_ascii_lowercase); + this.cmp(that) + } +} + +//--- Hashing + +impl Hash for Label { + /// Hash this label. + /// + /// All uppercase ASCII characters are lowercased beforehand. This way, + /// the hash of a label is case-independent, consistent with how labels + /// are compared and ordered. + /// + /// The label is hashed as if it were a name containing a single label -- + /// the length octet is thus included. This makes the hashing consistent + /// between names and tuples (not slices!) of labels. + fn hash(&self, state: &mut H) { + state.write_u8(self.len() as u8); + for &byte in self.as_bytes() { + state.write_u8(byte.to_ascii_lowercase()) + } + } +} + +//--- Formatting + +impl fmt::Display for Label { + /// Print a label. + /// + /// The label is printed in the conventional zone file format, with bytes + /// outside printable ASCII formatted as `\\DDD` (a backslash followed by + /// three zero-padded decimal digits), and `.` and `\\` simply escaped by + /// a backslash. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.as_bytes().iter().try_for_each(|&byte| { + if b".\\".contains(&byte) { + write!(f, "\\{}", byte as char) + } else if byte.is_ascii_graphic() { + write!(f, "{}", byte as char) + } else { + write!(f, "\\{:03}", byte) + } + }) + } +} + +impl fmt::Debug for Label { + /// Print a label for debugging purposes. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Label") + .field(&format_args!("{}", self)) + .finish() + } +} diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs index e288f9ee4..1cc63e1cd 100644 --- a/src/new_base/name/mod.rs +++ b/src/new_base/name/mod.rs @@ -13,3 +13,6 @@ //! The owner of `example.org` is thus responsible for _every_ domain ending //! with the `.example.org` suffix. The reverse order in which this hierarchy //! is expressed can sometimes be confusing. + +mod label; +pub use label::Label; From 66c4d198d911f25800e9bc2bf1c97b85943b1e42 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 10 Dec 2024 15:08:43 +0100 Subject: [PATCH 004/167] [new_base] Add module 'message' --- Cargo.lock | 26 +++- Cargo.toml | 6 + src/new_base/message.rs | 284 ++++++++++++++++++++++++++++++++++++++++ src/new_base/mod.rs | 1 + 4 files changed, 315 insertions(+), 2 deletions(-) create mode 100644 src/new_base/message.rs diff --git a/Cargo.lock b/Cargo.lock index 7f844fa92..7506702e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -278,6 +278,8 @@ dependencies = [ "tracing", "tracing-subscriber", "webpki-roots", + "zerocopy 0.8.13", + "zerocopy-derive 0.8.13", ] [[package]] @@ -797,7 +799,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -1690,7 +1692,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67914ab451f3bfd2e69e5e9d2ef3858484e7074d63f204fd166ec391b54de21d" +dependencies = [ + "zerocopy-derive 0.8.13", ] [[package]] @@ -1704,6 +1715,17 @@ dependencies = [ "syn", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7988d73a4303ca289df03316bc490e934accf371af6bc745393cf3c2c5c4f25d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 5fb61052e..0072d61fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,12 @@ tokio-stream = { version = "0.1.1", optional = true } tracing = { version = "0.1.40", optional = true } tracing-subscriber = { version = "0.3.18", optional = true, features = ["env-filter"] } +# 'zerocopy' provides simple derives for converting types to and from byte +# representations, along with network-endian integer primitives. These are +# used to define simple elements of DNS messages and their serialization. +zerocopy = "0.8" +zerocopy-derive = "0.8" + [features] default = ["std", "rand"] diff --git a/src/new_base/message.rs b/src/new_base/message.rs new file mode 100644 index 000000000..c07d605fa --- /dev/null +++ b/src/new_base/message.rs @@ -0,0 +1,284 @@ +//! DNS message headers. + +use core::fmt; + +use zerocopy::network_endian::U16; +use zerocopy_derive::*; + +//----------- Message -------------------------------------------------------- + +/// A DNS message. +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct Message { + /// The message header. + pub header: Header, + + /// The message contents. + pub contents: [u8], +} + +//----------- Header --------------------------------------------------------- + +/// A DNS message header. +#[derive( + Copy, + Clone, + Debug, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(C)] +pub struct Header { + /// A unique identifier for the message. + pub id: U16, + + /// Properties of the message. + pub flags: HeaderFlags, + + /// Counts of objects in the message. + pub counts: SectionCounts, +} + +//--- Formatting + +impl fmt::Display for Header { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} of ID {:04X} ({})", + self.flags, + self.id.get(), + self.counts + ) + } +} + +//----------- HeaderFlags ---------------------------------------------------- + +/// DNS message header flags. +#[derive( + Copy, + Clone, + Default, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct HeaderFlags { + inner: U16, +} + +//--- Interaction + +impl HeaderFlags { + /// Get the specified flag bit. + fn get_flag(&self, pos: u32) -> bool { + self.inner.get() & (1 << pos) != 0 + } + + /// Set the specified flag bit. + fn set_flag(mut self, pos: u32, value: bool) -> Self { + self.inner &= !(1 << pos); + self.inner |= (value as u16) << pos; + self + } + + /// The raw flags bits. + pub fn bits(&self) -> u16 { + self.inner.get() + } + + /// Whether this is a query. + pub fn is_query(&self) -> bool { + !self.get_flag(15) + } + + /// Whether this is a response. + pub fn is_response(&self) -> bool { + self.get_flag(15) + } + + /// The operation code. + pub fn opcode(&self) -> u8 { + (self.inner.get() >> 11) as u8 & 0xF + } + + /// The response code. + pub fn rcode(&self) -> u8 { + self.inner.get() as u8 & 0xF + } + + /// Construct a query. + pub fn query(mut self, opcode: u8) -> Self { + assert!(opcode < 16); + self.inner &= !(0xF << 11); + self.inner |= (opcode as u16) << 11; + self.set_flag(15, false) + } + + /// Construct a response. + pub fn respond(mut self, rcode: u8) -> Self { + assert!(rcode < 16); + self.inner &= !0xF; + self.inner |= rcode as u16; + self.set_flag(15, true) + } + + /// Whether this is an authoritative answer. + pub fn is_authoritative(&self) -> bool { + self.get_flag(10) + } + + /// Mark this as an authoritative answer. + pub fn set_authoritative(self, value: bool) -> Self { + self.set_flag(10, value) + } + + /// Whether this message is truncated. + pub fn is_truncated(&self) -> bool { + self.get_flag(9) + } + + /// Mark this message as truncated. + pub fn set_truncated(self, value: bool) -> Self { + self.set_flag(9, value) + } + + /// Whether the server should query recursively. + pub fn should_recurse(&self) -> bool { + self.get_flag(8) + } + + /// Direct the server to query recursively. + pub fn request_recursion(self, value: bool) -> Self { + self.set_flag(8, value) + } + + /// Whether the server supports recursion. + pub fn can_recurse(&self) -> bool { + self.get_flag(7) + } + + /// Indicate support for recursive queries. + pub fn support_recursion(self, value: bool) -> Self { + self.set_flag(7, value) + } +} + +//--- Formatting + +impl fmt::Debug for HeaderFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HeaderFlags") + .field("is_response (qr)", &self.is_response()) + .field("opcode", &self.opcode()) + .field("is_authoritative (aa)", &self.is_authoritative()) + .field("is_truncated (tc)", &self.is_truncated()) + .field("should_recurse (rd)", &self.should_recurse()) + .field("can_recurse (ra)", &self.can_recurse()) + .field("rcode", &self.rcode()) + .field("bits", &self.inner.get()) + .finish() + } +} + +impl fmt::Display for HeaderFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_query() { + if self.should_recurse() { + f.write_str("recursive ")?; + } + write!(f, "query (opcode {})", self.opcode())?; + } else { + if self.is_authoritative() { + f.write_str("authoritative ")?; + } + if self.should_recurse() && self.can_recurse() { + f.write_str("recursive ")?; + } + write!(f, "response (rcode {})", self.rcode())?; + } + + if self.is_truncated() { + f.write_str(" (message truncated)")?; + } + + Ok(()) + } +} + +//----------- SectionCounts -------------------------------------------------- + +/// Counts of objects in a DNS message. +#[derive( + Copy, + Clone, + Debug, + Default, + PartialEq, + Eq, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(C)] +pub struct SectionCounts { + /// The number of questions in the message. + pub questions: U16, + + /// The number of answer records in the message. + pub answers: U16, + + /// The number of name server records in the message. + pub authorities: U16, + + /// The number of additional records in the message. + pub additional: U16, +} + +//--- Formatting + +impl fmt::Display for SectionCounts { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut some = false; + + for (num, single, many) in [ + (self.questions.get(), "question", "questions"), + (self.answers.get(), "answer", "answers"), + (self.authorities.get(), "authority", "authorities"), + (self.additional.get(), "additional", "additional"), + ] { + // Add a comma if we have printed something before. + if some && num > 0 { + f.write_str(", ")?; + } + + // Print a count of this section. + match num { + 0 => {} + 1 => write!(f, "1 {single}")?, + n => write!(f, "{n} {many}")?, + } + + some |= num > 0; + } + + if !some { + f.write_str("empty")?; + } + + Ok(()) + } +} diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index c29a0b49f..368416354 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -4,4 +4,5 @@ //! with DNS. Most importantly, it provides functionality for parsing and //! building DNS messages on the wire. +pub mod message; pub mod name; From 48051b46abb396ac089b74854cff2a3836bcb481 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 10 Dec 2024 15:09:49 +0100 Subject: [PATCH 005/167] [new_base/name/label] Use 'zerocopy' --- src/new_base/name/label.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index d7d83b6a1..296a167fe 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -1,16 +1,19 @@ //! Labels in domain names. -//----------- Label ---------------------------------------------------------- - use core::{ cmp::Ordering, fmt, hash::{Hash, Hasher}, }; +use zerocopy_derive::*; + +//----------- Label ---------------------------------------------------------- + /// A label in a domain name. /// /// A label contains up to 63 bytes of arbitrary data. +#[derive(IntoBytes, Immutable, Unaligned)] #[repr(transparent)] pub struct Label([u8]); From be78a8f201e375645eef3ddf54b84a8683aaf60a Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 10 Dec 2024 15:26:00 +0100 Subject: [PATCH 006/167] [new_base] Add module 'parse' --- src/new_base/mod.rs | 6 ++++- src/new_base/parse.rs | 51 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 src/new_base/parse.rs diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 368416354..8a60f64c9 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -4,5 +4,9 @@ //! with DNS. Most importantly, it provides functionality for parsing and //! building DNS messages on the wire. -pub mod message; +mod message; +pub use message::{Header, HeaderFlags, Message, SectionCounts}; + pub mod name; + +pub mod parse; diff --git a/src/new_base/parse.rs b/src/new_base/parse.rs new file mode 100644 index 000000000..2a1697fb2 --- /dev/null +++ b/src/new_base/parse.rs @@ -0,0 +1,51 @@ +//! Parsing DNS messages from the wire format. + +use core::fmt; + +//----------- Low-level parsing traits --------------------------------------- + +/// Parsing from the start of a byte string. +pub trait SplitFrom<'a>: Sized { + /// Parse a value of [`Self`] from the start of the byte string. + /// + /// If parsing is successful, the parsed value and the rest of the string + /// are returned. Otherwise, a [`ParseError`] is returned. + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError>; +} + +/// Parsing from a byte string. +pub trait ParseFrom<'a>: Sized { + /// Parse a value of [`Self`] from the given byte string. + /// + /// If parsing is successful, the parsed value is returned. Otherwise, a + /// [`ParseError`] is returned. + fn parse_from(bytes: &'a [u8]) -> Result; +} + +//----------- ParseError ----------------------------------------------------- + +/// A DNS parsing error. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ParseError; + +//--- Formatting + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("DNS data could not be parsed from the wire format") + } +} + +//--- Conversion from 'zerocopy' errors + +impl From> for ParseError { + fn from(_: zerocopy::ConvertError) -> Self { + Self + } +} + +impl From> for ParseError { + fn from(_: zerocopy::SizeError) -> Self { + Self + } +} From 7a1a847717912a30a43a22b22c3ffd96799f12b3 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 10 Dec 2024 15:26:16 +0100 Subject: [PATCH 007/167] [new_base/name] Add module 'parsed' --- src/new_base/name/mod.rs | 3 + src/new_base/name/parsed.rs | 131 ++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 src/new_base/name/parsed.rs diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs index 1cc63e1cd..9ee96824a 100644 --- a/src/new_base/name/mod.rs +++ b/src/new_base/name/mod.rs @@ -16,3 +16,6 @@ mod label; pub use label::Label; + +mod parsed; +pub use parsed::ParsedName; diff --git a/src/new_base/name/parsed.rs b/src/new_base/name/parsed.rs new file mode 100644 index 000000000..abf592e5d --- /dev/null +++ b/src/new_base/name/parsed.rs @@ -0,0 +1,131 @@ +//! Domain names encoded in DNS messages. + +use zerocopy_derive::*; + +use crate::new_base::parse::{ParseError, ParseFrom, SplitFrom}; + +//----------- ParsedName ----------------------------------------------------- + +/// A domain name in a DNS message. +#[derive(Debug, IntoBytes, Immutable, Unaligned)] +#[repr(transparent)] +pub struct ParsedName([u8]); + +//--- Constants + +impl ParsedName { + /// The maximum size of a parsed domain name in the wire format. + /// + /// This can occur if a compression pointer is used to point to a root + /// name, even though such a representation is longer than copying the + /// root label into the name. + pub const MAX_SIZE: usize = 256; + + /// The root name. + pub const ROOT: &'static Self = { + // SAFETY: A root label is the shortest valid name. + unsafe { Self::from_bytes_unchecked(&[0u8]) } + }; +} + +//--- Construction + +impl ParsedName { + /// Assume a byte string is a valid [`ParsedName`]. + /// + /// # Safety + /// + /// The byte string must be correctly encoded in the wire format, and + /// within the size restriction (256 bytes or fewer). It must end with a + /// root label or a compression pointer. + pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'ParsedName' is 'repr(transparent)' to '[u8]', so casting a + // '[u8]' into a 'ParsedName' is sound. + core::mem::transmute(bytes) + } +} + +//--- Inspection + +impl ParsedName { + /// The size of this name in the wire format. + #[allow(clippy::len_without_is_empty)] + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Whether this is the root label. + pub const fn is_root(&self) -> bool { + self.0.len() == 1 + } + + /// Whether this is a compression pointer. + pub const fn is_pointer(&self) -> bool { + self.0.len() == 2 + } + + /// The wire format representation of the name. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } +} + +//--- Parsing + +impl<'a> SplitFrom<'a> for &'a ParsedName { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + // Iterate through the labels in the name. + let mut index = 0usize; + loop { + if index >= ParsedName::MAX_SIZE || index >= bytes.len() { + return Err(ParseError); + } + let length = bytes[index]; + if length == 0 { + // This was the root label. + index += 1; + break; + } else if length < 0x40 { + // This was the length of the label. + index += 1 + length as usize; + } else if length >= 0xC0 { + // This was a compression pointer. + if index + 1 >= bytes.len() { + return Err(ParseError); + } + index += 2; + break; + } else { + // This was a reserved or deprecated label type. + return Err(ParseError); + } + } + + let (name, bytes) = bytes.split_at(index); + // SAFETY: 'bytes' has been confirmed to be correctly encoded. + Ok((unsafe { ParsedName::from_bytes_unchecked(name) }, bytes)) + } +} + +impl<'a> ParseFrom<'a> for &'a ParsedName { + fn parse_from(bytes: &'a [u8]) -> Result { + Self::split_from(bytes).and_then(|(name, rest)| { + rest.is_empty().then_some(name).ok_or(ParseError) + }) + } +} + +//--- Conversion to and from bytes + +impl AsRef<[u8]> for ParsedName { + /// The bytes in the name in the wire format. + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl<'a> From<&'a ParsedName> for &'a [u8] { + fn from(name: &'a ParsedName) -> Self { + name.as_bytes() + } +} From ed95534b0ea589a8ec965443a87c066522fd6113 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 10 Dec 2024 15:31:52 +0100 Subject: [PATCH 008/167] [new_base] Add module 'question' --- src/new_base/mod.rs | 3 ++ src/new_base/question.rs | 104 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 src/new_base/question.rs diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 8a60f64c9..499187fb9 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -9,4 +9,7 @@ pub use message::{Header, HeaderFlags, Message, SectionCounts}; pub mod name; +mod question; +pub use question::{QClass, QType, Question}; + pub mod parse; diff --git a/src/new_base/question.rs b/src/new_base/question.rs new file mode 100644 index 000000000..16e388c1c --- /dev/null +++ b/src/new_base/question.rs @@ -0,0 +1,104 @@ +//! DNS questions. + +use zerocopy::{network_endian::U16, FromBytes}; +use zerocopy_derive::*; + +use super::{ + name::ParsedName, + parse::{ParseError, ParseFrom, SplitFrom}, +}; + +//----------- Question ------------------------------------------------------- + +/// A DNS question. +pub struct Question<'a> { + /// The domain name being requested. + pub qname: &'a ParsedName, + + /// The type of the requested records. + pub qtype: QType, + + /// The class of the requested records. + pub qclass: QClass, +} + +//--- Construction + +impl<'a> Question<'a> { + /// Construct a new [`Question`]. + pub fn new(qname: &'a ParsedName, qtype: QType, qclass: QClass) -> Self { + Self { + qname, + qtype, + qclass, + } + } +} + +//--- Parsing + +impl<'a> SplitFrom<'a> for Question<'a> { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (qname, rest) = <&ParsedName>::split_from(bytes)?; + let (qtype, rest) = QType::read_from_prefix(rest)?; + let (qclass, rest) = QClass::read_from_prefix(rest)?; + Ok((Self::new(qname, qtype, qclass), rest)) + } +} + +impl<'a> ParseFrom<'a> for Question<'a> { + fn parse_from(bytes: &'a [u8]) -> Result { + let (qname, rest) = <&ParsedName>::split_from(bytes)?; + let (qtype, rest) = QType::read_from_prefix(rest)?; + let qclass = QClass::read_from_bytes(rest)?; + Ok(Self::new(qname, qtype, qclass)) + } +} + +//----------- QType ---------------------------------------------------------- + +/// The type of a question. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct QType { + /// The type code. + pub code: U16, +} + +//----------- QClass --------------------------------------------------------- + +/// The class of a question. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct QClass { + /// The class code. + pub code: U16, +} From 37bc7d2145791de17d535e4b70812c76eb086fea Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 11 Dec 2024 15:59:23 +0100 Subject: [PATCH 009/167] Add module 'record' --- src/new_base/mod.rs | 3 + src/new_base/record/mod.rs | 155 +++++++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 src/new_base/record/mod.rs diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 499187fb9..f7baced2b 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -12,4 +12,7 @@ pub mod name; mod question; pub use question::{QClass, QType, Question}; +pub mod record; +pub use record::Record; + pub mod parse; diff --git a/src/new_base/record/mod.rs b/src/new_base/record/mod.rs new file mode 100644 index 000000000..fc348b710 --- /dev/null +++ b/src/new_base/record/mod.rs @@ -0,0 +1,155 @@ +//! DNS records. + +use zerocopy::{ + network_endian::{U16, U32}, + FromBytes, +}; +use zerocopy_derive::*; + +use super::{ + name::ParsedName, + parse::{ParseError, ParseFrom, SplitFrom}, +}; + +//----------- Record --------------------------------------------------------- + +/// An unparsed DNS record. +pub struct Record<'a> { + /// The name of the record. + pub rname: &'a ParsedName, + + /// The type of the record. + pub rtype: RType, + + /// The class of the record. + pub rclass: RClass, + + /// How long the record is reliable for. + pub ttl: TTL, + + /// Unparsed record data. + pub rdata: &'a [u8], +} + +//--- Construction + +impl<'a> Record<'a> { + /// Construct a new [`Record`]. + pub fn new( + rname: &'a ParsedName, + rtype: RType, + rclass: RClass, + ttl: TTL, + rdata: &'a [u8], + ) -> Self { + Self { + rname, + rtype, + rclass, + ttl, + rdata, + } + } +} + +//--- Parsing + +impl<'a> SplitFrom<'a> for Record<'a> { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (rname, rest) = <&ParsedName>::split_from(bytes)?; + let (rtype, rest) = RType::read_from_prefix(rest)?; + let (rclass, rest) = RClass::read_from_prefix(rest)?; + let (ttl, rest) = TTL::read_from_prefix(rest)?; + let (size, rest) = U16::read_from_prefix(rest)?; + let size = size.get() as usize; + let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; + + Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) + } +} + +impl<'a> ParseFrom<'a> for Record<'a> { + fn parse_from(bytes: &'a [u8]) -> Result { + let (rname, rest) = <&ParsedName>::split_from(bytes)?; + let (rtype, rest) = RType::read_from_prefix(rest)?; + let (rclass, rest) = RClass::read_from_prefix(rest)?; + let (ttl, rest) = TTL::read_from_prefix(rest)?; + let (size, rest) = U16::read_from_prefix(rest)?; + let size = size.get() as usize; + let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; + + Ok(Self::new(rname, rtype, rclass, ttl, rdata)) + } +} + +//----------- RType ---------------------------------------------------------- + +/// The type of a record. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct RType { + /// The type code. + pub code: U16, +} + +//----------- RClass --------------------------------------------------------- + +/// The class of a record. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct RClass { + /// The class code. + pub code: U16, +} + +//----------- TTL ------------------------------------------------------------ + +/// How long a record can be cached. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct TTL { + /// The underlying value. + pub value: U32, +} From 26653c391c8c936ccfa7a06fe083e8a7144643d5 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 11 Dec 2024 15:59:38 +0100 Subject: [PATCH 010/167] [new_base] Add high-level parsing traits --- src/new_base/parse/message.rs | 49 ++++++++ src/new_base/{parse.rs => parse/mod.rs} | 34 ++++++ src/new_base/parse/question.rs | 148 ++++++++++++++++++++++++ src/new_base/parse/record.rs | 148 ++++++++++++++++++++++++ 4 files changed, 379 insertions(+) create mode 100644 src/new_base/parse/message.rs rename src/new_base/{parse.rs => parse/mod.rs} (62%) create mode 100644 src/new_base/parse/question.rs create mode 100644 src/new_base/parse/record.rs diff --git a/src/new_base/parse/message.rs b/src/new_base/parse/message.rs new file mode 100644 index 000000000..eaea9845d --- /dev/null +++ b/src/new_base/parse/message.rs @@ -0,0 +1,49 @@ +//! Parsing DNS messages. + +use core::ops::ControlFlow; + +use crate::new_base::{Header, Question, Record}; + +/// A type that can be constructed by parsing a DNS message. +pub trait ParseMessage<'a>: Sized { + /// The type of visitors for incrementally building the output. + type Visitor: VisitMessagePart<'a>; + + /// The type of errors from converting a visitor into [`Self`]. + // TODO: Just use 'Visitor::Error'? + type Error; + + /// Construct a visitor, providing the message header. + fn make_visitor(header: &'a Header) + -> Result; + + /// Convert a visitor back to this type. + fn from_visitor(visitor: Self::Visitor) -> Result; +} + +/// A type that can visit the components of a DNS message. +pub trait VisitMessagePart<'a> { + /// The type of errors produced by visits. + type Error; + + /// Visit a component of the message. + fn visit( + &mut self, + component: MessagePart<'a>, + ) -> Result, Self::Error>; +} + +/// A component of a DNS message. +pub enum MessagePart<'a> { + /// A question. + Question(Question<'a>), + + /// An answer record. + Answer(Record<'a>), + + /// An authority record. + Authority(Record<'a>), + + /// An additional record. + Additional(Record<'a>), +} diff --git a/src/new_base/parse.rs b/src/new_base/parse/mod.rs similarity index 62% rename from src/new_base/parse.rs rename to src/new_base/parse/mod.rs index 2a1697fb2..a273717be 100644 --- a/src/new_base/parse.rs +++ b/src/new_base/parse/mod.rs @@ -2,6 +2,17 @@ use core::fmt; +use zerocopy::{FromBytes, Immutable, KnownLayout}; + +mod message; +pub use message::{MessagePart, ParseMessage, VisitMessagePart}; + +mod question; +pub use question::{ParseQuestion, ParseQuestions, VisitQuestion}; + +mod record; +pub use record::{ParseRecord, ParseRecords, VisitRecord}; + //----------- Low-level parsing traits --------------------------------------- /// Parsing from the start of a byte string. @@ -22,6 +33,29 @@ pub trait ParseFrom<'a>: Sized { fn parse_from(bytes: &'a [u8]) -> Result; } +//--- Carrying over 'zerocopy' traits + +// NOTE: We can't carry over 'read_from_prefix' because the trait impls would +// conflict. We kept 'ref_from_prefix' since it's more general. + +impl<'a, T: ?Sized> SplitFrom<'a> for &'a T +where + T: FromBytes + KnownLayout + Immutable, +{ + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + T::ref_from_prefix(bytes).map_err(|_| ParseError) + } +} + +impl<'a, T: ?Sized> ParseFrom<'a> for &'a T +where + T: FromBytes + KnownLayout + Immutable, +{ + fn parse_from(bytes: &'a [u8]) -> Result { + T::ref_from_bytes(bytes).map_err(|_| ParseError) + } +} + //----------- ParseError ----------------------------------------------------- /// A DNS parsing error. diff --git a/src/new_base/parse/question.rs b/src/new_base/parse/question.rs new file mode 100644 index 000000000..e08ea6283 --- /dev/null +++ b/src/new_base/parse/question.rs @@ -0,0 +1,148 @@ +//! Parsing DNS questions. + +use core::{convert::Infallible, ops::ControlFlow}; + +#[cfg(feature = "std")] +use std::boxed::Box; +#[cfg(feature = "std")] +use std::vec::Vec; + +use crate::new_base::Question; + +//----------- Trait definitions ---------------------------------------------- + +/// A type that can be constructed by parsing exactly one DNS question. +pub trait ParseQuestion<'a>: Sized { + /// The type of parse errors. + // TODO: Remove entirely? + type Error; + + /// Parse the given DNS question. + fn parse_question( + question: Question<'a>, + ) -> Result, Self::Error>; +} + +/// A type that can be constructed by parsing zero or more DNS questions. +pub trait ParseQuestions<'a>: Sized { + /// The type of visitors for incrementally building the output. + type Visitor: Default + VisitQuestion<'a>; + + /// The type of errors from converting a visitor into [`Self`]. + // TODO: Just use 'Visitor::Error'? Or remove entirely? + type Error; + + /// Convert a visitor back to this type. + fn from_visitor(visitor: Self::Visitor) -> Result; +} + +/// A type that can visit DNS questions. +pub trait VisitQuestion<'a> { + /// The type of errors produced by visits. + type Error; + + /// Visit a question. + fn visit_question( + &mut self, + question: Question<'a>, + ) -> Result, Self::Error>; +} + +//----------- Trait implementations ------------------------------------------ + +impl<'a> ParseQuestion<'a> for Question<'a> { + type Error = Infallible; + + fn parse_question( + question: Question<'a>, + ) -> Result, Self::Error> { + Ok(ControlFlow::Break(question)) + } +} + +//--- Impls for 'Option' + +impl<'a, T: ParseQuestion<'a>> ParseQuestion<'a> for Option { + type Error = T::Error; + + fn parse_question( + question: Question<'a>, + ) -> Result, Self::Error> { + Ok(match T::parse_question(question)? { + ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Option { + type Visitor = Option; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor) + } +} + +impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Option { + type Error = T::Error; + + fn visit_question( + &mut self, + question: Question<'a>, + ) -> Result, Self::Error> { + if self.is_some() { + return Ok(ControlFlow::Continue(())); + } + + Ok(match T::parse_question(question)? { + ControlFlow::Break(elem) => { + *self = Some(elem); + ControlFlow::Break(()) + } + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +//--- Impls for 'Vec' + +#[cfg(feature = "std")] +impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Vec { + type Visitor = Vec; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor) + } +} + +#[cfg(feature = "std")] +impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Vec { + type Error = T::Error; + + fn visit_question( + &mut self, + question: Question<'a>, + ) -> Result, Self::Error> { + Ok(match T::parse_question(question)? { + ControlFlow::Break(elem) => { + self.push(elem); + ControlFlow::Break(()) + } + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +//--- Impls for 'Box<[T]>' + +#[cfg(feature = "std")] +impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Box<[T]> { + type Visitor = Vec; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor.into_boxed_slice()) + } +} diff --git a/src/new_base/parse/record.rs b/src/new_base/parse/record.rs new file mode 100644 index 000000000..c93f2f8d1 --- /dev/null +++ b/src/new_base/parse/record.rs @@ -0,0 +1,148 @@ +//! Parsing DNS records. + +use core::{convert::Infallible, ops::ControlFlow}; + +#[cfg(feature = "std")] +use std::boxed::Box; +#[cfg(feature = "std")] +use std::vec::Vec; + +use crate::new_base::Record; + +//----------- Trait definitions ---------------------------------------------- + +/// A type that can be constructed by parsing exactly one DNS record. +pub trait ParseRecord<'a>: Sized { + /// The type of parse errors. + // TODO: Remove entirely? + type Error; + + /// Parse the given DNS record. + fn parse_record( + record: Record<'a>, + ) -> Result, Self::Error>; +} + +/// A type that can be constructed by parsing zero or more DNS records. +pub trait ParseRecords<'a>: Sized { + /// The type of visitors for incrementally building the output. + type Visitor: Default + VisitRecord<'a>; + + /// The type of errors from converting a visitor into [`Self`]. + // TODO: Just use 'Visitor::Error'? Or remove entirely? + type Error; + + /// Convert a visitor back to this type. + fn from_visitor(visitor: Self::Visitor) -> Result; +} + +/// A type that can visit DNS records. +pub trait VisitRecord<'a> { + /// The type of errors produced by visits. + type Error; + + /// Visit a record. + fn visit_record( + &mut self, + record: Record<'a>, + ) -> Result, Self::Error>; +} + +//----------- Trait implementations ------------------------------------------ + +impl<'a> ParseRecord<'a> for Record<'a> { + type Error = Infallible; + + fn parse_record( + record: Record<'a>, + ) -> Result, Self::Error> { + Ok(ControlFlow::Break(record)) + } +} + +//--- Impls for 'Option' + +impl<'a, T: ParseRecord<'a>> ParseRecord<'a> for Option { + type Error = T::Error; + + fn parse_record( + record: Record<'a>, + ) -> Result, Self::Error> { + Ok(match T::parse_record(record)? { + ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Option { + type Visitor = Option; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor) + } +} + +impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Option { + type Error = T::Error; + + fn visit_record( + &mut self, + record: Record<'a>, + ) -> Result, Self::Error> { + if self.is_some() { + return Ok(ControlFlow::Continue(())); + } + + Ok(match T::parse_record(record)? { + ControlFlow::Break(elem) => { + *self = Some(elem); + ControlFlow::Break(()) + } + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +//--- Impls for 'Vec' + +#[cfg(feature = "std")] +impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Vec { + type Visitor = Vec; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor) + } +} + +#[cfg(feature = "std")] +impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Vec { + type Error = T::Error; + + fn visit_record( + &mut self, + record: Record<'a>, + ) -> Result, Self::Error> { + Ok(match T::parse_record(record)? { + ControlFlow::Break(elem) => { + self.push(elem); + ControlFlow::Break(()) + } + ControlFlow::Continue(()) => ControlFlow::Continue(()), + }) + } +} + +//--- Impls for 'Box<[T]>' + +#[cfg(feature = "std")] +impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Box<[T]> { + type Visitor = Vec; + type Error = Infallible; + + fn from_visitor(visitor: Self::Visitor) -> Result { + Ok(visitor.into_boxed_slice()) + } +} From e1c701ff288918dd6fb0ede61594d36e8fa4632a Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 11 Dec 2024 23:15:55 +0100 Subject: [PATCH 011/167] [new_base/name] Add module 'reversed' --- src/new_base/name/label.rs | 69 +++++++++++ src/new_base/name/mod.rs | 6 +- src/new_base/name/reversed.rs | 214 ++++++++++++++++++++++++++++++++++ 3 files changed, 286 insertions(+), 3 deletions(-) create mode 100644 src/new_base/name/reversed.rs diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 296a167fe..48420df3a 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -4,10 +4,13 @@ use core::{ cmp::Ordering, fmt, hash::{Hash, Hasher}, + iter::FusedIterator, }; use zerocopy_derive::*; +use crate::new_base::parse::{ParseError, SplitFrom}; + //----------- Label ---------------------------------------------------------- /// A label in a domain name. @@ -47,6 +50,21 @@ impl Label { } } +//--- Parsing + +impl<'a> SplitFrom<'a> for &'a Label { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (&size, rest) = bytes.split_first().ok_or(ParseError)?; + if size < 64 && rest.len() >= size as usize { + let (label, rest) = bytes.split_at(1 + size as usize); + // SAFETY: 'label' begins with a valid length octet. + Ok((unsafe { Label::from_bytes_unchecked(label) }, rest)) + } else { + Err(ParseError) + } + } +} + //--- Inspection impl Label { @@ -182,3 +200,54 @@ impl fmt::Debug for Label { .finish() } } + +//----------- LabelIter ------------------------------------------------------ + +/// An iterator over encoded [`Label`]s. +#[derive(Clone)] +pub struct LabelIter<'a> { + /// The buffer being read from. + /// + /// It is assumed to contain valid encoded labels. + bytes: &'a [u8], +} + +//--- Construction + +impl<'a> LabelIter<'a> { + /// Construct a new [`LabelIter`]. + /// + /// The byte string must contain a sequence of valid encoded labels. + pub const unsafe fn new_unchecked(bytes: &'a [u8]) -> Self { + Self { bytes } + } +} + +//--- Inspection + +impl<'a> LabelIter<'a> { + /// The remaining labels. + pub const fn remaining(&self) -> &'a [u8] { + self.bytes + } +} + +//--- Iteration + +impl<'a> Iterator for LabelIter<'a> { + type Item = &'a Label; + + fn next(&mut self) -> Option { + if self.bytes.is_empty() { + return None; + } + + // SAFETY: 'bytes' is assumed to only contain valid labels. + let (head, tail) = + unsafe { <&Label>::split_from(self.bytes).unwrap_unchecked() }; + self.bytes = tail; + Some(head) + } +} + +impl FusedIterator for LabelIter<'_> {} diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs index 9ee96824a..9270f4d5c 100644 --- a/src/new_base/name/mod.rs +++ b/src/new_base/name/mod.rs @@ -15,7 +15,7 @@ //! is expressed can sometimes be confusing. mod label; -pub use label::Label; +pub use label::{Label, LabelIter}; -mod parsed; -pub use parsed::ParsedName; +mod reversed; +pub use reversed::{RevName, RevNameBuf}; diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs new file mode 100644 index 000000000..6283c5322 --- /dev/null +++ b/src/new_base/name/reversed.rs @@ -0,0 +1,214 @@ +//! Reversed DNS names. + +use core::{ + borrow::Borrow, + cmp::Ordering, + hash::{Hash, Hasher}, + ops::Deref, +}; + +use zerocopy_derive::*; + +use super::LabelIter; + +//----------- RevName -------------------------------------------------------- + +/// A domain name in reversed order. +/// +/// Domain names are conventionally presented and encoded from the innermost +/// label to the root label. This ordering is inconvenient and difficult to +/// use, making many common operations (e.g. comparing and ordering domain +/// names) more computationally expensive. A [`RevName`] stores the labels in +/// reversed order for more efficient use. +#[derive(Immutable, Unaligned)] +#[repr(transparent)] +pub struct RevName([u8]); + +//--- Constants + +impl RevName { + /// The maximum size of a (reversed) domain name. + /// + /// This is the same as the maximum size of a regular domain name. + pub const MAX_SIZE: usize = 255; + + /// The root name. + pub const ROOT: &'static Self = { + // SAFETY: A root label is the shortest valid name. + unsafe { Self::from_bytes_unchecked(&[0u8]) } + }; +} + +//--- Construction + +impl RevName { + /// Assume a byte string is a valid [`RevName`]. + /// + /// # Safety + /// + /// The byte string must begin with a root label (0-value byte). It must + /// be followed by any number of encoded labels, as long as the size of + /// the whole string is 255 bytes or less. + pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'RevName' is 'repr(transparent)' to '[u8]', so casting a + // '[u8]' into a 'RevName' is sound. + core::mem::transmute(bytes) + } +} + +//--- Inspection + +impl RevName { + /// The size of this name in the wire format. + #[allow(clippy::len_without_is_empty)] + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Whether this is the root label. + pub const fn is_root(&self) -> bool { + self.0.len() == 1 + } + + /// A byte representation of the [`RevName`]. + /// + /// Note that labels appear in reverse order to the _conventional_ format + /// (it thus starts with the root label). + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } + + /// The labels in the [`RevName`]. + /// + /// Note that labels appear in reverse order to the _conventional_ format + /// (it thus starts with the root label). + pub const fn labels(&self) -> LabelIter<'_> { + // SAFETY: A 'RevName' always contains valid encoded labels. + unsafe { LabelIter::new_unchecked(self.as_bytes()) } + } +} + +//--- Equality + +impl PartialEq for RevName { + fn eq(&self, that: &Self) -> bool { + // Instead of iterating labels, blindly iterate bytes. The locations + // of labels don't matter since we're testing everything for equality. + + // NOTE: Label lengths (which are less than 64) aren't affected by + // 'to_ascii_lowercase', so this method can be applied uniformly. + let this = self.as_bytes().iter().map(u8::to_ascii_lowercase); + let that = that.as_bytes().iter().map(u8::to_ascii_lowercase); + + this.eq(that) + } +} + +impl Eq for RevName {} + +//--- Comparison + +impl PartialOrd for RevName { + fn partial_cmp(&self, that: &Self) -> Option { + Some(self.cmp(that)) + } +} + +impl Ord for RevName { + fn cmp(&self, that: &Self) -> Ordering { + // Unfortunately, names cannot be compared bytewise. Labels are + // preceded by their length octets, but a longer label can be less + // than a shorter one if its first bytes are less. We are forced to + // compare lexicographically over labels. + self.labels().cmp(that.labels()) + } +} + +//--- Hashing + +impl Hash for RevName { + fn hash(&self, state: &mut H) { + for byte in self.as_bytes() { + // NOTE: Label lengths (which are less than 64) aren't affected by + // 'to_ascii_lowercase', so this method can be applied uniformly. + state.write_u8(byte.to_ascii_lowercase()) + } + } +} + +//----------- RevNameBuf ----------------------------------------------------- + +/// A 256-byte buffer containing a [`RevName`]. +#[derive(Immutable, Unaligned)] +#[repr(C)] // make layout compatible with '[u8; 256]' +pub struct RevNameBuf { + /// The position of the root label in the buffer. + offset: u8, + + /// The buffer containing the [`RevName`]. + buffer: [u8; 255], +} + +//--- Construction + +impl RevNameBuf { + /// Copy a [`RevName`] into a buffer. + pub fn copy_from(name: &RevName) -> Self { + let offset = 255 - name.len() as u8; + let mut buffer = [0u8; 255]; + buffer[offset as usize..].copy_from_slice(name.as_bytes()); + Self { offset, buffer } + } +} + +//--- Access to the underlying 'RevName' + +impl Deref for RevNameBuf { + type Target = RevName; + + fn deref(&self) -> &Self::Target { + let name = &self.buffer[self.offset as usize..]; + // SAFETY: A 'RevNameBuf' always contains a valid 'RevName'. + unsafe { RevName::from_bytes_unchecked(name) } + } +} + +impl Borrow for RevNameBuf { + fn borrow(&self) -> &RevName { + self + } +} + +impl AsRef for RevNameBuf { + fn as_ref(&self) -> &RevName { + self + } +} + +//--- Forwarding equality, comparison, and hashing + +impl PartialEq for RevNameBuf { + fn eq(&self, that: &Self) -> bool { + **self == **that + } +} + +impl Eq for RevNameBuf {} + +impl PartialOrd for RevNameBuf { + fn partial_cmp(&self, that: &Self) -> Option { + Some(self.cmp(that)) + } +} + +impl Ord for RevNameBuf { + fn cmp(&self, that: &Self) -> Ordering { + (**self).cmp(&**that) + } +} + +impl Hash for RevNameBuf { + fn hash(&self, state: &mut H) { + (**self).hash(state) + } +} From 7ef218dc5e07c80abdce2e5837e9d4f868bb0969 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 00:13:40 +0100 Subject: [PATCH 012/167] [new_base/name/reversed] Implement complex parsing --- src/new_base/name/reversed.rs | 193 +++++++++++++++++++++++++++++++++- src/new_base/parse/mod.rs | 32 +++++- 2 files changed, 223 insertions(+), 2 deletions(-) diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 6283c5322..c884696e9 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -4,11 +4,19 @@ use core::{ borrow::Borrow, cmp::Ordering, hash::{Hash, Hasher}, - ops::Deref, + ops::{Deref, Range}, }; +use zerocopy::IntoBytes; use zerocopy_derive::*; +use crate::new_base::{ + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + Message, +}; + use super::LabelIter; //----------- RevName -------------------------------------------------------- @@ -152,6 +160,14 @@ pub struct RevNameBuf { //--- Construction impl RevNameBuf { + /// Construct an empty, invalid buffer. + fn empty() -> Self { + Self { + offset: 0, + buffer: [0; 255], + } + } + /// Copy a [`RevName`] into a buffer. pub fn copy_from(name: &RevName) -> Self { let offset = 255 - name.len() as u8; @@ -161,6 +177,181 @@ impl RevNameBuf { } } +//--- Parsing from DNS messages + +impl<'a> SplitFromMessage<'a> for RevNameBuf { + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + // NOTE: The input may be controlled by an attacker. Compression + // pointers can be arranged to cause loops or to access every byte in + // the message in random order. Instead of performing complex loop + // detection, which would probably perform allocations, we simply + // disallow a name to point to data _after_ it. Standard name + // compressors will never generate such pointers. + + let message = message.as_bytes(); + let mut buffer = Self::empty(); + + // Perform the first iteration early, to catch the end of the name. + let bytes = message.get(start..).ok_or(ParseError)?; + let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; + let orig_end = message.len() - rest.len(); + + // Traverse compression pointers. + while let Some(start) = pointer.map(usize::from) { + // Ensure the referenced position comes earlier. + if start >= start { + return Err(ParseError); + } + + // Keep going, from the referenced position. + let bytes = message.get(start..).ok_or(ParseError)?; + (pointer, _) = parse_segment(bytes, &mut buffer)?; + continue; + } + + // Stop and return the original end. + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been prepended into it). + Ok((buffer, orig_end)) + } +} + +impl<'a> ParseFromMessage<'a> for RevNameBuf { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + // See 'split_from_message()' for details. The only differences are + // in the range of the first iteration, and the check that the first + // iteration exactly covers the input range. + + let message = message.as_bytes(); + let mut buffer = Self::empty(); + + // Perform the first iteration early, to catch the end of the name. + let bytes = message.get(range.clone()).ok_or(ParseError)?; + let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; + + if !rest.is_empty() { + // The name didn't reach the end of the input range, fail. + return Err(ParseError); + } + + // Traverse compression pointers. + while let Some(start) = pointer.map(usize::from) { + // Ensure the referenced position comes earlier. + if start >= start { + return Err(ParseError); + } + + // Keep going, from the referenced position. + let bytes = message.get(start..).ok_or(ParseError)?; + (pointer, _) = parse_segment(bytes, &mut buffer)?; + continue; + } + + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been prepended into it). + Ok(buffer) + } +} + +/// Parse an encoded and potentially-compressed domain name, without +/// following any compression pointer. +fn parse_segment<'a>( + mut bytes: &'a [u8], + buffer: &mut RevNameBuf, +) -> Result<(Option, &'a [u8]), ParseError> { + loop { + let (&length, rest) = bytes.split_first().ok_or(ParseError)?; + if length == 0 { + // Found the root, stop. + buffer.prepend(&[0u8]); + return Ok((None, rest)); + } else if length < 64 { + // This looks like a regular label. + + if rest.len() < length as usize { + // The input doesn't contain the whole label. + return Err(ParseError); + } else if buffer.offset < 2 + length { + // The output name would exceed 254 bytes (this isn't + // the root label, so it can't fill the 255th byte). + return Err(ParseError); + } + + let (label, rest) = bytes.split_at(1 + length as usize); + buffer.prepend(label); + bytes = rest; + } else if length >= 0xC0 { + // This looks like a compression pointer. + + let (&extra, rest) = rest.split_first().ok_or(ParseError)?; + let pointer = u16::from_be_bytes([length, extra]); + + // NOTE: We don't verify the pointer here, that's left to + // the caller (since they have to actually use it). + return Ok((Some(pointer & 0x3FFF), rest)); + } else { + // This is an invalid or deprecated label type. + return Err(ParseError); + } + } +} + +//--- Parsing from general byte strings + +impl<'a> SplitFrom<'a> for RevNameBuf { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let mut buffer = Self::empty(); + + let (pointer, rest) = parse_segment(bytes, &mut buffer)?; + if pointer.is_some() { + // We can't follow compression pointers, so fail. + return Err(ParseError); + } + + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been prepended into it). + Ok((buffer, rest)) + } +} + +impl<'a> ParseFrom<'a> for RevNameBuf { + fn parse_from(bytes: &'a [u8]) -> Result { + let mut buffer = Self::empty(); + + let (pointer, rest) = parse_segment(bytes, &mut buffer)?; + if pointer.is_some() { + // We can't follow compression pointers, so fail. + return Err(ParseError); + } else if !rest.is_empty() { + // The name didn't reach the end of the input range, fail. + return Err(ParseError); + } + + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been prepended into it). + Ok(buffer) + } +} + +//--- Interaction + +impl RevNameBuf { + /// Prepend bytes to this buffer. + /// + /// This is an internal convenience function used while building buffers. + fn prepend(&mut self, label: &[u8]) { + self.offset -= label.len() as u8; + self.buffer[self.offset as usize..][..label.len()] + .copy_from_slice(label); + } +} + //--- Access to the underlying 'RevName' impl Deref for RevNameBuf { diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index a273717be..fac5a8e9d 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,6 +1,6 @@ //! Parsing DNS messages from the wire format. -use core::fmt; +use core::{fmt, ops::Range}; use zerocopy::{FromBytes, Immutable, KnownLayout}; @@ -13,6 +13,36 @@ pub use question::{ParseQuestion, ParseQuestions, VisitQuestion}; mod record; pub use record::{ParseRecord, ParseRecords, VisitRecord}; +use super::Message; + +//----------- Message-aware parsing traits ----------------------------------- + +/// A type that can be parsed from a DNS message. +pub trait SplitFromMessage<'a>: Sized { + /// Parse a value of [`Self`] from the start of a byte string within a + /// particular DNS message. + /// + /// If parsing is successful, the parsed value and the rest of the string + /// are returned. Otherwise, a [`ParseError`] is returned. + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError>; +} + +/// A type that can be parsed from a string in a DNS message. +pub trait ParseFromMessage<'a>: Sized { + /// Parse a value of [`Self`] from a byte string within a particular DNS + /// message. + /// + /// If parsing is successful, the parsed value is returned. Otherwise, a + /// [`ParseError`] is returned. + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result; +} + //----------- Low-level parsing traits --------------------------------------- /// Parsing from the start of a byte string. From 3c0b4cd951499507db1cbbacbf537c3b5247f8e4 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 00:23:31 +0100 Subject: [PATCH 013/167] [new_base/name] Add some 'Debug' impls --- src/new_base/name/label.rs | 16 ++++++++++++++++ src/new_base/name/reversed.rs | 26 ++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 48420df3a..597a5eb92 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -251,3 +251,19 @@ impl<'a> Iterator for LabelIter<'a> { } impl FusedIterator for LabelIter<'_> {} + +//--- Formatting + +impl fmt::Debug for LabelIter<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Labels<'a>(&'a LabelIter<'a>); + + impl fmt::Debug for Labels<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.0.clone()).finish() + } + } + + f.debug_tuple("LabelIter").field(&Labels(self)).finish() + } +} diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index c884696e9..1c315ba34 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -3,6 +3,7 @@ use core::{ borrow::Borrow, cmp::Ordering, + fmt, hash::{Hash, Hasher}, ops::{Deref, Range}, }; @@ -144,6 +145,31 @@ impl Hash for RevName { } } +//--- Formatting + +impl fmt::Debug for RevName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct RevLabels<'a>(&'a RevName); + + impl fmt::Debug for RevLabels<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut first = true; + self.0.labels().try_for_each(|label| { + if !first { + f.write_str(".")?; + } else { + first = false; + } + + label.fmt(f) + }) + } + } + + f.debug_tuple("RevName").field(&RevLabels(self)).finish() + } +} + //----------- RevNameBuf ----------------------------------------------------- /// A 256-byte buffer containing a [`RevName`]. From 4b8cc79fd22a5f6f1bb19016112829835d6d747b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 00:50:43 +0100 Subject: [PATCH 014/167] [new_base] Implement parsing for 'Question' and 'Record' --- src/new_base/name/reversed.rs | 2 +- src/new_base/parse/mod.rs | 36 +++++++- src/new_base/question.rs | 80 ++++++++++++++---- src/new_base/record/mod.rs | 150 +++++++++++++++++++++++++++++----- 4 files changed, 229 insertions(+), 39 deletions(-) diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 1c315ba34..6c58b7ada 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -328,7 +328,7 @@ fn parse_segment<'a>( } } -//--- Parsing from general byte strings +//--- Parsing from bytes impl<'a> SplitFrom<'a> for RevNameBuf { fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index fac5a8e9d..fba82d65c 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -2,7 +2,7 @@ use core::{fmt, ops::Range}; -use zerocopy::{FromBytes, Immutable, KnownLayout}; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; mod message; pub use message::{MessagePart, ParseMessage, VisitMessagePart}; @@ -43,6 +43,40 @@ pub trait ParseFromMessage<'a>: Sized { ) -> Result; } +//--- Carrying over 'zerocopy' traits + +// NOTE: We can't carry over 'read_from_prefix' because the trait impls would +// conflict. We kept 'ref_from_prefix' since it's more general. + +impl<'a, T: ?Sized> SplitFromMessage<'a> for &'a T +where + T: FromBytes + KnownLayout + Immutable, +{ + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let message = message.as_bytes(); + let bytes = message.get(start..).ok_or(ParseError)?; + let (this, rest) = T::ref_from_prefix(bytes)?; + Ok((this, message.len() - rest.len())) + } +} + +impl<'a, T: ?Sized> ParseFromMessage<'a> for &'a T +where + T: FromBytes + KnownLayout + Immutable, +{ + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let message = message.as_bytes(); + let bytes = message.get(range).ok_or(ParseError)?; + Ok(T::ref_from_bytes(bytes)?) + } +} + //----------- Low-level parsing traits --------------------------------------- /// Parsing from the start of a byte string. diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 16e388c1c..121eedff4 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -1,19 +1,25 @@ //! DNS questions. -use zerocopy::{network_endian::U16, FromBytes}; +use core::ops::Range; + +use zerocopy::network_endian::U16; use zerocopy_derive::*; use super::{ - name::ParsedName, - parse::{ParseError, ParseFrom, SplitFrom}, + name::RevNameBuf, + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + Message, }; //----------- Question ------------------------------------------------------- /// A DNS question. -pub struct Question<'a> { +#[derive(Clone)] +pub struct Question { /// The domain name being requested. - pub qname: &'a ParsedName, + pub qname: N, /// The type of the requested records. pub qtype: QType, @@ -22,11 +28,14 @@ pub struct Question<'a> { pub qclass: QClass, } +/// An unparsed DNS question. +pub type UnparsedQuestion = Question; + //--- Construction -impl<'a> Question<'a> { +impl Question { /// Construct a new [`Question`]. - pub fn new(qname: &'a ParsedName, qtype: QType, qclass: QClass) -> Self { + pub fn new(qname: N, qtype: QType, qclass: QClass) -> Self { Self { qname, qtype, @@ -35,22 +44,61 @@ impl<'a> Question<'a> { } } -//--- Parsing +//--- Parsing from DNS messages + +impl<'a, N> SplitFromMessage<'a> for Question +where + N: SplitFromMessage<'a>, +{ + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let (qname, rest) = N::split_from_message(message, start)?; + let (&qtype, rest) = <&QType>::split_from_message(message, rest)?; + let (&qclass, rest) = <&QClass>::split_from_message(message, rest)?; + Ok((Self::new(qname, qtype, qclass), rest)) + } +} + +impl<'a, N> ParseFromMessage<'a> for Question +where + N: SplitFromMessage<'a>, +{ + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let (qname, rest) = N::split_from_message(message, range.start)?; + let (&qtype, rest) = <&QType>::split_from_message(message, rest)?; + let &qclass = + <&QClass>::parse_from_message(message, rest..range.end)?; + Ok(Self::new(qname, qtype, qclass)) + } +} + +//--- Parsing from bytes -impl<'a> SplitFrom<'a> for Question<'a> { +impl<'a, N> SplitFrom<'a> for Question +where + N: SplitFrom<'a>, +{ fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (qname, rest) = <&ParsedName>::split_from(bytes)?; - let (qtype, rest) = QType::read_from_prefix(rest)?; - let (qclass, rest) = QClass::read_from_prefix(rest)?; + let (qname, rest) = N::split_from(bytes)?; + let (&qtype, rest) = <&QType>::split_from(rest)?; + let (&qclass, rest) = <&QClass>::split_from(rest)?; Ok((Self::new(qname, qtype, qclass), rest)) } } -impl<'a> ParseFrom<'a> for Question<'a> { +impl<'a, N> ParseFrom<'a> for Question +where + N: SplitFrom<'a>, +{ fn parse_from(bytes: &'a [u8]) -> Result { - let (qname, rest) = <&ParsedName>::split_from(bytes)?; - let (qtype, rest) = QType::read_from_prefix(rest)?; - let qclass = QClass::read_from_bytes(rest)?; + let (qname, rest) = N::split_from(bytes)?; + let (&qtype, rest) = <&QType>::split_from(rest)?; + let &qclass = <&QClass>::parse_from(rest)?; Ok(Self::new(qname, qtype, qclass)) } } diff --git a/src/new_base/record/mod.rs b/src/new_base/record/mod.rs index fc348b710..42336dc6f 100644 --- a/src/new_base/record/mod.rs +++ b/src/new_base/record/mod.rs @@ -1,22 +1,31 @@ //! DNS records. +use core::{ + borrow::Borrow, + ops::{Deref, Range}, +}; + use zerocopy::{ network_endian::{U16, U32}, - FromBytes, + FromBytes, IntoBytes, }; use zerocopy_derive::*; use super::{ - name::ParsedName, - parse::{ParseError, ParseFrom, SplitFrom}, + name::RevNameBuf, + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + Message, }; //----------- Record --------------------------------------------------------- -/// An unparsed DNS record. -pub struct Record<'a> { +/// A DNS record. +#[derive(Clone)] +pub struct Record { /// The name of the record. - pub rname: &'a ParsedName, + pub rname: N, /// The type of the record. pub rtype: RType, @@ -28,19 +37,22 @@ pub struct Record<'a> { pub ttl: TTL, /// Unparsed record data. - pub rdata: &'a [u8], + pub rdata: D, } +/// An unparsed DNS record. +pub type UnparsedRecord<'a> = Record; + //--- Construction -impl<'a> Record<'a> { +impl Record { /// Construct a new [`Record`]. pub fn new( - rname: &'a ParsedName, + rname: N, rtype: RType, rclass: RClass, ttl: TTL, - rdata: &'a [u8], + rdata: D, ) -> Self { Self { rname, @@ -52,31 +64,35 @@ impl<'a> Record<'a> { } } -//--- Parsing +//--- Parsing from bytes -impl<'a> SplitFrom<'a> for Record<'a> { +impl<'a, N, D> SplitFrom<'a> for Record +where + N: SplitFrom<'a>, + D: SplitFrom<'a>, +{ fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (rname, rest) = <&ParsedName>::split_from(bytes)?; + let (rname, rest) = N::split_from(bytes)?; let (rtype, rest) = RType::read_from_prefix(rest)?; let (rclass, rest) = RClass::read_from_prefix(rest)?; let (ttl, rest) = TTL::read_from_prefix(rest)?; - let (size, rest) = U16::read_from_prefix(rest)?; - let size = size.get() as usize; - let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; + let (rdata, rest) = D::split_from(rest)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } } -impl<'a> ParseFrom<'a> for Record<'a> { +impl<'a, N, D> ParseFrom<'a> for Record +where + N: SplitFrom<'a>, + D: ParseFrom<'a>, +{ fn parse_from(bytes: &'a [u8]) -> Result { - let (rname, rest) = <&ParsedName>::split_from(bytes)?; + let (rname, rest) = N::split_from(bytes)?; let (rtype, rest) = RType::read_from_prefix(rest)?; let (rclass, rest) = RClass::read_from_prefix(rest)?; let (ttl, rest) = TTL::read_from_prefix(rest)?; - let (size, rest) = U16::read_from_prefix(rest)?; - let size = size.get() as usize; - let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; + let rdata = D::parse_from(rest)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } @@ -153,3 +169,95 @@ pub struct TTL { /// The underlying value. pub value: U32, } + +//----------- UnparsedRecordData --------------------------------------------- + +/// Unparsed DNS record data. +#[derive(Immutable, Unaligned)] +#[repr(transparent)] +pub struct UnparsedRecordData([u8]); + +//--- Construction + +impl UnparsedRecordData { + /// Assume a byte string is a valid [`UnparsedRecordData`]. + /// + /// # Safety + /// + /// The byte string must be 65,535 bytes or shorter. + pub const unsafe fn new_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'UnparsedRecordData' is 'repr(transparent)' to '[u8]', so + // casting a '[u8]' into an 'UnparsedRecordData' is sound. + core::mem::transmute(bytes) + } +} + +//--- Parsing from DNS messages + +impl<'a> SplitFromMessage<'a> for &'a UnparsedRecordData { + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let message = message.as_bytes(); + let bytes = message.get(start..).ok_or(ParseError)?; + let (this, rest) = Self::split_from(bytes)?; + Ok((this, message.len() - rest.len())) + } +} + +impl<'a> ParseFromMessage<'a> for &'a UnparsedRecordData { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let message = message.as_bytes(); + let bytes = message.get(range).ok_or(ParseError)?; + Self::parse_from(bytes) + } +} + +//--- Parsing from bytes + +impl<'a> SplitFrom<'a> for &'a UnparsedRecordData { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (size, rest) = U16::read_from_prefix(bytes)?; + let size = size.get() as usize; + let (data, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; + // SAFETY: 'data.len() == size' which is a 'u16'. + let this = unsafe { UnparsedRecordData::new_unchecked(data) }; + Ok((this, rest)) + } +} + +impl<'a> ParseFrom<'a> for &'a UnparsedRecordData { + fn parse_from(bytes: &'a [u8]) -> Result { + let (size, rest) = U16::read_from_prefix(bytes)?; + let size = size.get() as usize; + let data = <[u8]>::ref_from_bytes_with_elems(rest, size)?; + // SAFETY: 'data.len() == size' which is a 'u16'. + Ok(unsafe { UnparsedRecordData::new_unchecked(data) }) + } +} + +//--- Access to the underlying bytes + +impl Deref for UnparsedRecordData { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Borrow<[u8]> for UnparsedRecordData { + fn borrow(&self) -> &[u8] { + self + } +} + +impl AsRef<[u8]> for UnparsedRecordData { + fn as_ref(&self) -> &[u8] { + self + } +} From b56260c95d4e58ab37eedf28410dd4416b6e5349 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 00:57:39 +0100 Subject: [PATCH 015/167] [new_base/parse] Update with new question/record types --- src/new_base/mod.rs | 4 ++-- src/new_base/name/reversed.rs | 2 +- src/new_base/parse/message.rs | 10 ++++----- src/new_base/parse/question.rs | 38 +++++++++++++++++----------------- src/new_base/parse/record.rs | 18 ++++++++-------- 5 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index f7baced2b..a444989e4 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -10,9 +10,9 @@ pub use message::{Header, HeaderFlags, Message, SectionCounts}; pub mod name; mod question; -pub use question::{QClass, QType, Question}; +pub use question::{QClass, QType, Question, UnparsedQuestion}; pub mod record; -pub use record::Record; +pub use record::{Record, UnparsedRecord}; pub mod parse; diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 6c58b7ada..864a5e1bc 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -173,7 +173,7 @@ impl fmt::Debug for RevName { //----------- RevNameBuf ----------------------------------------------------- /// A 256-byte buffer containing a [`RevName`]. -#[derive(Immutable, Unaligned)] +#[derive(Clone, Immutable, Unaligned)] #[repr(C)] // make layout compatible with '[u8; 256]' pub struct RevNameBuf { /// The position of the root label in the buffer. diff --git a/src/new_base/parse/message.rs b/src/new_base/parse/message.rs index eaea9845d..1c964588a 100644 --- a/src/new_base/parse/message.rs +++ b/src/new_base/parse/message.rs @@ -2,7 +2,7 @@ use core::ops::ControlFlow; -use crate::new_base::{Header, Question, Record}; +use crate::new_base::{Header, UnparsedQuestion, UnparsedRecord}; /// A type that can be constructed by parsing a DNS message. pub trait ParseMessage<'a>: Sized { @@ -36,14 +36,14 @@ pub trait VisitMessagePart<'a> { /// A component of a DNS message. pub enum MessagePart<'a> { /// A question. - Question(Question<'a>), + Question(&'a UnparsedQuestion), /// An answer record. - Answer(Record<'a>), + Answer(&'a UnparsedRecord<'a>), /// An authority record. - Authority(Record<'a>), + Authority(&'a UnparsedRecord<'a>), /// An additional record. - Additional(Record<'a>), + Additional(&'a UnparsedRecord<'a>), } diff --git a/src/new_base/parse/question.rs b/src/new_base/parse/question.rs index e08ea6283..784cadc09 100644 --- a/src/new_base/parse/question.rs +++ b/src/new_base/parse/question.rs @@ -7,26 +7,26 @@ use std::boxed::Box; #[cfg(feature = "std")] use std::vec::Vec; -use crate::new_base::Question; +use crate::new_base::UnparsedQuestion; //----------- Trait definitions ---------------------------------------------- /// A type that can be constructed by parsing exactly one DNS question. -pub trait ParseQuestion<'a>: Sized { +pub trait ParseQuestion: Sized { /// The type of parse errors. // TODO: Remove entirely? type Error; /// Parse the given DNS question. fn parse_question( - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error>; } /// A type that can be constructed by parsing zero or more DNS questions. -pub trait ParseQuestions<'a>: Sized { +pub trait ParseQuestions: Sized { /// The type of visitors for incrementally building the output. - type Visitor: Default + VisitQuestion<'a>; + type Visitor: Default + VisitQuestion; /// The type of errors from converting a visitor into [`Self`]. // TODO: Just use 'Visitor::Error'? Or remove entirely? @@ -37,36 +37,36 @@ pub trait ParseQuestions<'a>: Sized { } /// A type that can visit DNS questions. -pub trait VisitQuestion<'a> { +pub trait VisitQuestion { /// The type of errors produced by visits. type Error; /// Visit a question. fn visit_question( &mut self, - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error>; } //----------- Trait implementations ------------------------------------------ -impl<'a> ParseQuestion<'a> for Question<'a> { +impl ParseQuestion for UnparsedQuestion { type Error = Infallible; fn parse_question( - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error> { - Ok(ControlFlow::Break(question)) + Ok(ControlFlow::Break(question.clone())) } } //--- Impls for 'Option' -impl<'a, T: ParseQuestion<'a>> ParseQuestion<'a> for Option { +impl ParseQuestion for Option { type Error = T::Error; fn parse_question( - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error> { Ok(match T::parse_question(question)? { ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), @@ -75,7 +75,7 @@ impl<'a, T: ParseQuestion<'a>> ParseQuestion<'a> for Option { } } -impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Option { +impl ParseQuestions for Option { type Visitor = Option; type Error = Infallible; @@ -84,12 +84,12 @@ impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Option { } } -impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Option { +impl VisitQuestion for Option { type Error = T::Error; fn visit_question( &mut self, - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error> { if self.is_some() { return Ok(ControlFlow::Continue(())); @@ -108,7 +108,7 @@ impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Option { //--- Impls for 'Vec' #[cfg(feature = "std")] -impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Vec { +impl ParseQuestions for Vec { type Visitor = Vec; type Error = Infallible; @@ -118,12 +118,12 @@ impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Vec { } #[cfg(feature = "std")] -impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Vec { +impl VisitQuestion for Vec { type Error = T::Error; fn visit_question( &mut self, - question: Question<'a>, + question: &UnparsedQuestion, ) -> Result, Self::Error> { Ok(match T::parse_question(question)? { ControlFlow::Break(elem) => { @@ -138,7 +138,7 @@ impl<'a, T: ParseQuestion<'a>> VisitQuestion<'a> for Vec { //--- Impls for 'Box<[T]>' #[cfg(feature = "std")] -impl<'a, T: ParseQuestion<'a>> ParseQuestions<'a> for Box<[T]> { +impl ParseQuestions for Box<[T]> { type Visitor = Vec; type Error = Infallible; diff --git a/src/new_base/parse/record.rs b/src/new_base/parse/record.rs index c93f2f8d1..75e98a36a 100644 --- a/src/new_base/parse/record.rs +++ b/src/new_base/parse/record.rs @@ -7,7 +7,7 @@ use std::boxed::Box; #[cfg(feature = "std")] use std::vec::Vec; -use crate::new_base::Record; +use crate::new_base::UnparsedRecord; //----------- Trait definitions ---------------------------------------------- @@ -19,7 +19,7 @@ pub trait ParseRecord<'a>: Sized { /// Parse the given DNS record. fn parse_record( - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error>; } @@ -44,19 +44,19 @@ pub trait VisitRecord<'a> { /// Visit a record. fn visit_record( &mut self, - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error>; } //----------- Trait implementations ------------------------------------------ -impl<'a> ParseRecord<'a> for Record<'a> { +impl<'a> ParseRecord<'a> for UnparsedRecord<'a> { type Error = Infallible; fn parse_record( - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error> { - Ok(ControlFlow::Break(record)) + Ok(ControlFlow::Break(record.clone())) } } @@ -66,7 +66,7 @@ impl<'a, T: ParseRecord<'a>> ParseRecord<'a> for Option { type Error = T::Error; fn parse_record( - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error> { Ok(match T::parse_record(record)? { ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), @@ -89,7 +89,7 @@ impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Option { fn visit_record( &mut self, - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error> { if self.is_some() { return Ok(ControlFlow::Continue(())); @@ -123,7 +123,7 @@ impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Vec { fn visit_record( &mut self, - record: Record<'a>, + record: &UnparsedRecord<'a>, ) -> Result, Self::Error> { Ok(match T::parse_record(record)? { ControlFlow::Break(elem) => { From 1622e586b5c5273bb794f48abea89d490249f5ca Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 01:00:48 +0100 Subject: [PATCH 016/167] [new_base/name] Fix bugs (thanks clippy) --- src/new_base/name/label.rs | 2 ++ src/new_base/name/reversed.rs | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 597a5eb92..b93b32f80 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -217,6 +217,8 @@ pub struct LabelIter<'a> { impl<'a> LabelIter<'a> { /// Construct a new [`LabelIter`]. /// + /// # Safety + /// /// The byte string must contain a sequence of valid encoded labels. pub const unsafe fn new_unchecked(bytes: &'a [u8]) -> Self { Self { bytes } diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 864a5e1bc..60a640f91 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -226,15 +226,17 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { let orig_end = message.len() - rest.len(); // Traverse compression pointers. + let mut old_start = start; while let Some(start) = pointer.map(usize::from) { // Ensure the referenced position comes earlier. - if start >= start { + if start >= old_start { return Err(ParseError); } // Keep going, from the referenced position. let bytes = message.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; + old_start = start; continue; } @@ -267,15 +269,17 @@ impl<'a> ParseFromMessage<'a> for RevNameBuf { } // Traverse compression pointers. + let mut old_start = range.start; while let Some(start) = pointer.map(usize::from) { // Ensure the referenced position comes earlier. - if start >= start { + if start >= old_start { return Err(ParseError); } // Keep going, from the referenced position. let bytes = message.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; + old_start = start; continue; } From f27f922eba6ce469c4171dab08129c3413c46f4c Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 12 Dec 2024 01:04:03 +0100 Subject: [PATCH 017/167] [new_base] Make 'record' a private module --- src/new_base/mod.rs | 6 ++++-- src/new_base/{record/mod.rs => record.rs} | 0 2 files changed, 4 insertions(+), 2 deletions(-) rename src/new_base/{record/mod.rs => record.rs} (100%) diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index a444989e4..4307896fb 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -12,7 +12,9 @@ pub mod name; mod question; pub use question::{QClass, QType, Question, UnparsedQuestion}; -pub mod record; -pub use record::{Record, UnparsedRecord}; +mod record; +pub use record::{ + RClass, RType, Record, UnparsedRecord, UnparsedRecordData, TTL, +}; pub mod parse; diff --git a/src/new_base/record/mod.rs b/src/new_base/record.rs similarity index 100% rename from src/new_base/record/mod.rs rename to src/new_base/record.rs From dd123c12055c5b7dfbca2ed94cf53f2f4aac1da1 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 16 Dec 2024 13:05:28 +0100 Subject: [PATCH 018/167] Use 'zerocopy' 0.8.5 or newer It implements 'Hash' for the provided integer types. --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0072d61fa..9d078a5e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,8 +51,8 @@ tracing-subscriber = { version = "0.3.18", optional = true, features = ["env-fil # 'zerocopy' provides simple derives for converting types to and from byte # representations, along with network-endian integer primitives. These are # used to define simple elements of DNS messages and their serialization. -zerocopy = "0.8" -zerocopy-derive = "0.8" +zerocopy = "0.8.5" +zerocopy-derive = "0.8.5" [features] default = ["std", "rand"] From 46b2e45873879fc6b3b0019c0ec8b32f8b032c52 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 16 Dec 2024 14:33:26 +0100 Subject: [PATCH 019/167] Add module 'new_rdata' with most RFC 1035 types --- src/lib.rs | 1 + src/new_rdata/mod.rs | 3 + src/new_rdata/rfc1035.rs | 319 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 323 insertions(+) create mode 100644 src/new_rdata/mod.rs create mode 100644 src/new_rdata/rfc1035.rs diff --git a/src/lib.rs b/src/lib.rs index e9aef12b8..b2f7ac66c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -194,6 +194,7 @@ pub mod base; pub mod dep; pub mod net; pub mod new_base; +pub mod new_rdata; pub mod rdata; pub mod resolv; pub mod sign; diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs new file mode 100644 index 000000000..54afb39ee --- /dev/null +++ b/src/new_rdata/mod.rs @@ -0,0 +1,3 @@ +//! Record data types. + +pub mod rfc1035; diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs new file mode 100644 index 000000000..b8d893dff --- /dev/null +++ b/src/new_rdata/rfc1035.rs @@ -0,0 +1,319 @@ +//! Core record data types. + +use core::{fmt, net::Ipv4Addr, ops::Range, str::FromStr}; + +use zerocopy::network_endian::{U16, U32}; +use zerocopy_derive::*; + +use crate::new_base::{ + parse::{ParseError, ParseFromMessage, SplitFromMessage}, + Message, +}; + +//----------- A -------------------------------------------------------------- + +/// The IPv4 address of a host responsible for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct A { + /// The IPv4 address octets. + pub octets: [u8; 4], +} + +//--- Converting to and from 'Ipv4Addr' + +impl From for A { + fn from(value: Ipv4Addr) -> Self { + Self { + octets: value.octets(), + } + } +} + +impl From for Ipv4Addr { + fn from(value: A) -> Self { + Self::from(value.octets) + } +} + +//--- Parsing from a string + +impl FromStr for A { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + Ipv4Addr::from_str(s).map(A::from) + } +} + +//--- Formatting + +impl fmt::Display for A { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Ipv4Addr::from(*self).fmt(f) + } +} + +//----------- Ns ------------------------------------------------------------- + +/// The authoritative name server for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct Ns { + /// The name of the authoritative server. + pub name: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + N::parse_from_message(message, range).map(|name| Self { name }) + } +} + +//----------- Cname ---------------------------------------------------------- + +/// The canonical name for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct Cname { + /// The canonical name. + pub name: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Cname { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + N::parse_from_message(message, range).map(|name| Self { name }) + } +} + +//----------- Soa ------------------------------------------------------------ + +/// The start of a zone of authority. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct Soa { + /// The name server which provided this zone. + pub mname: N, + + /// The mailbox of the maintainer of this zone. + pub rname: N, + + /// The version number of the original copy of this zone. + // TODO: Define a dedicated serial number type. + pub serial: U32, + + /// The number of seconds to wait until refreshing the zone. + pub refresh: U32, + + /// The number of seconds to wait until retrying a failed refresh. + pub retry: U32, + + /// The number of seconds until the zone is considered expired. + pub expire: U32, + + /// The minimum TTL for any record in this zone. + pub minimum: U32, +} + +//--- Parsing from DNS messages + +impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let (mname, rest) = N::split_from_message(message, range.start)?; + let (rname, rest) = N::split_from_message(message, rest)?; + let (&serial, rest) = <&U32>::split_from_message(message, rest)?; + let (&refresh, rest) = <&U32>::split_from_message(message, rest)?; + let (&retry, rest) = <&U32>::split_from_message(message, rest)?; + let (&expire, rest) = <&U32>::split_from_message(message, rest)?; + let &minimum = <&U32>::parse_from_message(message, rest..range.end)?; + + Ok(Self { + mname, + rname, + serial, + refresh, + retry, + expire, + minimum, + }) + } +} + +//----------- Wks ------------------------------------------------------------ + +/// Well-known services supported on this domain. +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C, packed)] +pub struct Wks { + /// The address of the host providing these services. + pub address: A, + + /// The IP protocol number for the services (e.g. TCP). + pub protocol: u8, + + /// A bitset of supported well-known ports. + pub ports: [u8], +} + +//--- Formatting + +impl fmt::Debug for Wks { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Ports<'a>(&'a [u8]); + + impl fmt::Debug for Ports<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let entries = self + .0 + .iter() + .enumerate() + .flat_map(|(i, &b)| (0..8).map(move |j| (i, j, b))) + .filter(|(_, j, b)| b & (1 << j) != 0) + .map(|(i, j, _)| i * 8 + j); + + f.debug_set().entries(entries).finish() + } + } + + f.debug_struct("Wks") + .field("address", &Ipv4Addr::from(self.address)) + .field("protocol", &self.protocol) + .field("ports", &Ports(&self.ports)) + .finish() + } +} + +//----------- Ptr ------------------------------------------------------------ + +/// A pointer to another domain name. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct Ptr { + /// The referenced domain name. + pub name: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + N::parse_from_message(message, range).map(|name| Self { name }) + } +} + +// TODO: MINFO, HINFO, and TXT records, which need 'CharStr'. + +//----------- Mx ------------------------------------------------------------- + +/// A host that can exchange mail for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(C)] +pub struct Mx { + /// The preference for this host over others. + pub preference: U16, + + /// The domain name of the mail exchanger. + pub exchange: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let (&preference, rest) = + <&U16>::split_from_message(message, range.start)?; + let exchange = N::parse_from_message(message, rest..range.end)?; + Ok(Self { + preference, + exchange, + }) + } +} From 9e238ae43263172fc60ddf59229e2b436db020cd Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 16 Dec 2024 14:43:57 +0100 Subject: [PATCH 020/167] [new_base] Define 'CharStr' --- src/new_base/charstr.rs | 78 +++++++++++++++++++++++++++++++++++++++++ src/new_base/mod.rs | 3 ++ 2 files changed, 81 insertions(+) create mode 100644 src/new_base/charstr.rs diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs new file mode 100644 index 000000000..62d289718 --- /dev/null +++ b/src/new_base/charstr.rs @@ -0,0 +1,78 @@ +//! DNS "character strings". + +use core::ops::Range; + +use zerocopy::IntoBytes; +use zerocopy_derive::*; + +use super::{ + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + Message, +}; + +//----------- CharStr -------------------------------------------------------- + +/// A DNS "character string". +#[derive(Immutable, Unaligned)] +#[repr(transparent)] +pub struct CharStr { + /// The underlying octets. + pub octets: [u8], +} + +//--- Parsing from DNS messages + +impl<'a> SplitFromMessage<'a> for &'a CharStr { + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let bytes = &message.as_bytes()[start..]; + let (this, rest) = Self::split_from(bytes)?; + Ok((this, bytes.len() - rest.len())) + } +} + +impl<'a> ParseFromMessage<'a> for &'a CharStr { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + message + .as_bytes() + .get(range) + .ok_or(ParseError) + .and_then(Self::parse_from) + } +} + +//--- Parsing from bytes + +impl<'a> SplitFrom<'a> for &'a CharStr { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (&length, rest) = bytes.split_first().ok_or(ParseError)?; + if length as usize > rest.len() { + return Err(ParseError); + } + let (bytes, rest) = rest.split_at(length as usize); + + // SAFETY: 'CharStr' is 'repr(transparent)' to '[u8]'. + Ok((unsafe { core::mem::transmute::<&[u8], Self>(bytes) }, rest)) + } +} + +impl<'a> ParseFrom<'a> for &'a CharStr { + fn parse_from(bytes: &'a [u8]) -> Result { + let (&length, rest) = bytes.split_first().ok_or(ParseError)?; + if length as usize != rest.len() { + return Err(ParseError); + } + + // SAFETY: 'CharStr' is 'repr(transparent)' to '[u8]'. + Ok(unsafe { core::mem::transmute::<&[u8], Self>(rest) }) + } +} + +// TODO: Formatting diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 4307896fb..428584b68 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -9,6 +9,9 @@ pub use message::{Header, HeaderFlags, Message, SectionCounts}; pub mod name; +mod charstr; +pub use charstr::CharStr; + mod question; pub use question::{QClass, QType, Question, UnparsedQuestion}; From bb63a1ee8ba95ef99f95721030ca4931b82fe860 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 16 Dec 2024 14:55:55 +0100 Subject: [PATCH 021/167] [new_rdata] Define 'Hinfo' --- src/new_rdata/rfc1035.rs | 49 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index b8d893dff..52bfd02df 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -2,12 +2,17 @@ use core::{fmt, net::Ipv4Addr, ops::Range, str::FromStr}; -use zerocopy::network_endian::{U16, U32}; +use zerocopy::{ + network_endian::{U16, U32}, + IntoBytes, +}; use zerocopy_derive::*; use crate::new_base::{ - parse::{ParseError, ParseFromMessage, SplitFromMessage}, - Message, + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + CharStr, Message, }; //----------- A -------------------------------------------------------------- @@ -272,7 +277,41 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { } } -// TODO: MINFO, HINFO, and TXT records, which need 'CharStr'. +//----------- Hinfo ---------------------------------------------------------- + +/// Information about the host computer. +pub struct Hinfo<'a> { + /// The CPU type. + pub cpu: &'a CharStr, + + /// The OS type. + pub os: &'a CharStr, +} + +//--- Parsing from DNS messages + +impl<'a> ParseFromMessage<'a> for Hinfo<'a> { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + message + .as_bytes() + .get(range) + .ok_or(ParseError) + .and_then(Self::parse_from) + } +} + +//--- Parsing from bytes + +impl<'a> ParseFrom<'a> for Hinfo<'a> { + fn parse_from(bytes: &'a [u8]) -> Result { + let (cpu, rest) = <&CharStr>::split_from(bytes)?; + let os = <&CharStr>::parse_from(rest)?; + Ok(Self { cpu, os }) + } +} //----------- Mx ------------------------------------------------------------- @@ -317,3 +356,5 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { }) } } + +// TODO: TXT records. From 5ec06dca109fbb81106ccea92358582f543addd8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 25 Dec 2024 16:41:15 +0100 Subject: [PATCH 022/167] [new_rdata/rfc1035] Implement (basic) 'Txt' records --- src/new_rdata/rfc1035.rs | 42 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 52bfd02df..80db47141 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -357,4 +357,44 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { } } -// TODO: TXT records. +//----------- Txt ------------------------------------------------------------ + +/// Free-form text strings about this domain. +#[derive(IntoBytes, Immutable, Unaligned)] +#[repr(transparent)] +pub struct Txt { + /// The text strings, as concatenated [`CharStr`]s. + content: [u8], +} + +// TODO: Support for iterating over the contained 'CharStr's. + +//--- Parsing from DNS messages + +impl<'a> ParseFromMessage<'a> for &'a Txt { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + message + .as_bytes() + .get(range) + .ok_or(ParseError) + .and_then(Self::parse_from) + } +} + +//--- Parsing from bytes + +impl<'a> ParseFrom<'a> for &'a Txt { + fn parse_from(bytes: &'a [u8]) -> Result { + // NOTE: The input must contain at least one 'CharStr'. + let (_, mut rest) = <&CharStr>::split_from(bytes)?; + while !rest.is_empty() { + (_, rest) = <&CharStr>::split_from(rest)?; + } + + // SAFETY: 'Txt' is 'repr(transparent)' to '[u8]'. + Ok(unsafe { core::mem::transmute::<&'a [u8], Self>(bytes) }) + } +} From 6ad095eed19a5b82ea357a3c98edd108e6bb7c6b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 25 Dec 2024 16:48:49 +0100 Subject: [PATCH 023/167] [new_rdata/rfc1035] Add 'ParseFrom' impls where missing --- src/new_rdata/rfc1035.rs | 61 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 80db47141..4d3a07c47 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -108,6 +108,14 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { } } +//--- Parsing from bytes + +impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ns { + fn parse_from(bytes: &'a [u8]) -> Result { + N::parse_from(bytes).map(|name| Self { name }) + } +} + //----------- Cname ---------------------------------------------------------- /// The canonical name for this domain. @@ -143,6 +151,14 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Cname { } } +//--- Parsing from bytes + +impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Cname { + fn parse_from(bytes: &'a [u8]) -> Result { + N::parse_from(bytes).map(|name| Self { name }) + } +} + //----------- Soa ------------------------------------------------------------ /// The start of a zone of authority. @@ -198,6 +214,30 @@ impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { } } +//--- Parsing from bytes + +impl<'a, N: SplitFrom<'a>> ParseFrom<'a> for Soa { + fn parse_from(bytes: &'a [u8]) -> Result { + let (mname, rest) = N::split_from(bytes)?; + let (rname, rest) = N::split_from(rest)?; + let (&serial, rest) = <&U32>::split_from(rest)?; + let (&refresh, rest) = <&U32>::split_from(rest)?; + let (&retry, rest) = <&U32>::split_from(rest)?; + let (&expire, rest) = <&U32>::split_from(rest)?; + let &minimum = <&U32>::parse_from(rest)?; + + Ok(Self { + mname, + rname, + serial, + refresh, + retry, + expire, + minimum, + }) + } +} + //----------- Wks ------------------------------------------------------------ /// Well-known services supported on this domain. @@ -277,6 +317,14 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { } } +//--- Parsing from bytes + +impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ptr { + fn parse_from(bytes: &'a [u8]) -> Result { + N::parse_from(bytes).map(|name| Self { name }) + } +} + //----------- Hinfo ---------------------------------------------------------- /// Information about the host computer. @@ -357,6 +405,19 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { } } +//--- Parsing from bytes + +impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Mx { + fn parse_from(bytes: &'a [u8]) -> Result { + let (&preference, rest) = <&U16>::split_from(bytes)?; + let exchange = N::parse_from(rest)?; + Ok(Self { + preference, + exchange, + }) + } +} + //----------- Txt ------------------------------------------------------------ /// Free-form text strings about this domain. From 90e15bb7dae03605077eaa579eec1b6a3a6ee202 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 25 Dec 2024 17:17:07 +0100 Subject: [PATCH 024/167] [new_rdata/rfc1035] Don't use 'zerocopy' around names --- src/new_rdata/rfc1035.rs | 64 +++------------------------------------- 1 file changed, 4 insertions(+), 60 deletions(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 4d3a07c47..9415f5362 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -76,21 +76,7 @@ impl fmt::Display for A { //----------- Ns ------------------------------------------------------------- /// The authoritative name server for this domain. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, -)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct Ns { /// The name of the authoritative server. @@ -119,21 +105,7 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ns { //----------- Cname ---------------------------------------------------------- /// The canonical name for this domain. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, -)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct Cname { /// The canonical name. @@ -285,21 +257,7 @@ impl fmt::Debug for Wks { //----------- Ptr ------------------------------------------------------------ /// A pointer to another domain name. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, -)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct Ptr { /// The referenced domain name. @@ -364,21 +322,7 @@ impl<'a> ParseFrom<'a> for Hinfo<'a> { //----------- Mx ------------------------------------------------------------- /// A host that can exchange mail for this domain. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, -)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(C)] pub struct Mx { /// The preference for this host over others. From c86a57e8373d487e5afc3bd57c2fa92d618ae550 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 25 Dec 2024 17:18:01 +0100 Subject: [PATCH 025/167] [new_base/charstr] Impl 'Eq' and 'Debug' --- src/new_base/charstr.rs | 52 ++++++++++++++++++++++++++++++++++++++-- src/new_rdata/rfc1035.rs | 1 + 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 62d289718..dfe2bc8f9 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -1,6 +1,6 @@ //! DNS "character strings". -use core::ops::Range; +use core::{fmt, ops::Range}; use zerocopy::IntoBytes; use zerocopy_derive::*; @@ -75,4 +75,52 @@ impl<'a> ParseFrom<'a> for &'a CharStr { } } -// TODO: Formatting +//--- Equality + +impl PartialEq for CharStr { + fn eq(&self, other: &Self) -> bool { + self.octets.eq_ignore_ascii_case(&other.octets) + } +} + +impl Eq for CharStr {} + +//--- Formatting + +impl fmt::Debug for CharStr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use fmt::Write; + + struct Native<'a>(&'a [u8]); + impl fmt::Debug for Native<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("b\"")?; + for &b in self.0 { + f.write_str(match b { + b'"' => "\\\"", + b' ' => " ", + b'\n' => "\\n", + b'\r' => "\\r", + b'\t' => "\\t", + b'\\' => "\\\\", + + _ => { + if b.is_ascii_graphic() { + f.write_char(b as char)?; + } else { + write!(f, "\\x{:02X}", b)?; + } + continue; + } + })?; + } + f.write_char('"')?; + Ok(()) + } + } + + f.debug_struct("CharStr") + .field("content", &Native(&self.octets)) + .finish() + } +} diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 9415f5362..e54c348be 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -286,6 +286,7 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ptr { //----------- Hinfo ---------------------------------------------------------- /// Information about the host computer. +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Hinfo<'a> { /// The CPU type. pub cpu: &'a CharStr, From a607c7748790d6ee1850f161feeb58e57843faaf Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 25 Dec 2024 17:36:11 +0100 Subject: [PATCH 026/167] [new_base] Add module 'serial' --- src/new_base/mod.rs | 3 ++ src/new_base/serial.rs | 87 ++++++++++++++++++++++++++++++++++++++++ src/new_rdata/rfc1035.rs | 9 ++--- 3 files changed, 94 insertions(+), 5 deletions(-) create mode 100644 src/new_base/serial.rs diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 428584b68..3becedb29 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -21,3 +21,6 @@ pub use record::{ }; pub mod parse; + +mod serial; +pub use serial::Serial; diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs new file mode 100644 index 000000000..fe00923c3 --- /dev/null +++ b/src/new_base/serial.rs @@ -0,0 +1,87 @@ +//! Serial number arithmetic. +//! +//! See [RFC 1982](https://datatracker.ietf.org/doc/html/rfc1982). + +use core::{ + cmp::Ordering, + fmt, + ops::{Add, AddAssign}, +}; + +use zerocopy::network_endian::U32; +use zerocopy_derive::*; + +//----------- Serial --------------------------------------------------------- + +/// A serial number. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct Serial(U32); + +//--- Addition + +impl Add for Serial { + type Output = Self; + + fn add(self, rhs: i32) -> Self::Output { + self.0.get().wrapping_add_signed(rhs).into() + } +} + +impl AddAssign for Serial { + fn add_assign(&mut self, rhs: i32) { + self.0 = self.0.get().wrapping_add_signed(rhs).into(); + } +} + +//--- Ordering + +impl PartialOrd for Serial { + fn partial_cmp(&self, other: &Self) -> Option { + let (lhs, rhs) = (self.0.get(), other.0.get()); + + if lhs == rhs { + Some(Ordering::Equal) + } else if lhs.abs_diff(rhs) == 1 << 31 { + None + } else if (lhs < rhs) ^ (lhs.abs_diff(rhs) > (1 << 31)) { + Some(Ordering::Less) + } else { + Some(Ordering::Greater) + } + } +} + +//--- Conversion to and from native integer types + +impl From for Serial { + fn from(value: u32) -> Self { + Self(U32::new(value)) + } +} + +impl From for u32 { + fn from(value: Serial) -> Self { + value.0.get() + } +} + +//--- Formatting + +impl fmt::Display for Serial { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.get().fmt(f) + } +} diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index e54c348be..4e09c9727 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -12,7 +12,7 @@ use crate::new_base::{ parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, }, - CharStr, Message, + CharStr, Message, Serial, }; //----------- A -------------------------------------------------------------- @@ -143,8 +143,7 @@ pub struct Soa { pub rname: N, /// The version number of the original copy of this zone. - // TODO: Define a dedicated serial number type. - pub serial: U32, + pub serial: Serial, /// The number of seconds to wait until refreshing the zone. pub refresh: U32, @@ -168,7 +167,7 @@ impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { ) -> Result { let (mname, rest) = N::split_from_message(message, range.start)?; let (rname, rest) = N::split_from_message(message, rest)?; - let (&serial, rest) = <&U32>::split_from_message(message, rest)?; + let (&serial, rest) = <&Serial>::split_from_message(message, rest)?; let (&refresh, rest) = <&U32>::split_from_message(message, rest)?; let (&retry, rest) = <&U32>::split_from_message(message, rest)?; let (&expire, rest) = <&U32>::split_from_message(message, rest)?; @@ -192,7 +191,7 @@ impl<'a, N: SplitFrom<'a>> ParseFrom<'a> for Soa { fn parse_from(bytes: &'a [u8]) -> Result { let (mname, rest) = N::split_from(bytes)?; let (rname, rest) = N::split_from(rest)?; - let (&serial, rest) = <&U32>::split_from(rest)?; + let (&serial, rest) = <&Serial>::split_from(rest)?; let (&refresh, rest) = <&U32>::split_from(rest)?; let (&retry, rest) = <&U32>::split_from(rest)?; let (&expire, rest) = <&U32>::split_from(rest)?; From 7731a35d6be33d5aebdcceeb0bd844999691c725 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 14:29:26 +0100 Subject: [PATCH 027/167] [new_base] Add module 'build' --- src/new_base/build/mod.rs | 41 +++++++++++++++++++++++++++++++++++++++ src/new_base/mod.rs | 19 ++++++++++++------ src/new_base/parse/mod.rs | 2 +- 3 files changed, 55 insertions(+), 7 deletions(-) create mode 100644 src/new_base/build/mod.rs diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs new file mode 100644 index 000000000..e0a1c4925 --- /dev/null +++ b/src/new_base/build/mod.rs @@ -0,0 +1,41 @@ +//! Building DNS messages in the wire format. + +use core::fmt; + +//----------- Low-level building traits -------------------------------------- + +/// Building into a byte string. +pub trait BuildInto { + /// Append this value to the byte string. + /// + /// If the byte string is long enough to fit the message, the remaining + /// (unfilled) part of the byte string is returned. Otherwise, a + /// [`TruncationError`] is returned. + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError>; +} + +impl BuildInto for &T { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + (**self).build_into(bytes) + } +} + +//----------- TruncationError ------------------------------------------------ + +/// A DNS message did not fit in a buffer. +#[derive(Clone, Debug, PartialEq, Hash)] +pub struct TruncationError; + +//--- Formatting + +impl fmt::Display for TruncationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("A buffer was too small to fit a DNS message") + } +} diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 3becedb29..899225cf8 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -4,14 +4,11 @@ //! with DNS. Most importantly, it provides functionality for parsing and //! building DNS messages on the wire. +//--- DNS messages + mod message; pub use message::{Header, HeaderFlags, Message, SectionCounts}; -pub mod name; - -mod charstr; -pub use charstr::CharStr; - mod question; pub use question::{QClass, QType, Question, UnparsedQuestion}; @@ -20,7 +17,17 @@ pub use record::{ RClass, RType, Record, UnparsedRecord, UnparsedRecordData, TTL, }; -pub mod parse; +//--- Elements of DNS messages + +pub mod name; + +mod charstr; +pub use charstr::CharStr; mod serial; pub use serial::Serial; + +//--- Wire format + +pub mod build; +pub mod parse; diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index fba82d65c..022ff9df2 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -122,7 +122,7 @@ where //----------- ParseError ----------------------------------------------------- -/// A DNS parsing error. +/// A DNS message parsing error. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct ParseError; From c27dd1f94a32635c3452a7f926196768c6c6a592 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 14:32:30 +0100 Subject: [PATCH 028/167] [new_base/build] Add a 'Builder' for DNS messages --- src/new_base/build/builder.rs | 353 ++++++++++++++++++++++++++++++++++ src/new_base/build/mod.rs | 27 +++ 2 files changed, 380 insertions(+) create mode 100644 src/new_base/build/builder.rs diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs new file mode 100644 index 000000000..7488f8f6a --- /dev/null +++ b/src/new_base/build/builder.rs @@ -0,0 +1,353 @@ +//! A builder for DNS messages. + +use core::{ + marker::PhantomData, + mem::ManuallyDrop, + ptr::{self, NonNull}, +}; + +use zerocopy::{FromBytes, IntoBytes, SizeError}; + +use crate::new_base::{name::RevName, Header, Message}; + +use super::TruncationError; + +//----------- Builder -------------------------------------------------------- + +/// A DNS message builder. +pub struct Builder<'b> { + /// The message being built. + /// + /// The message is divided into four parts: + /// + /// - The message header (borrowed mutably by this type). + /// - Committed message contents (borrowed *immutably* by this type). + /// - Appended message contents (borrowed mutably by this type). + /// - Uninitialized message contents (borrowed mutably by this type). + message: NonNull, + + _message: PhantomData<&'b mut Message>, + + /// Context for building. + context: &'b mut BuilderContext, + + /// The commit point of this builder. + /// + /// Message contents up to this point are committed and cannot be removed + /// by this builder. Message contents following this (up to the size in + /// the builder context) are appended but uncommitted. + commit: usize, +} + +//--- Initialization + +impl<'b> Builder<'b> { + /// Construct a [`Builder`] from raw parts. + /// + /// # Safety + /// + /// - `message` is a valid reference for the lifetime `'b`. + /// - `message.header` is mutably borrowed for `'b`. + /// - `message.contents[..commit]` is immutably borrowed for `'b`. + /// - `message.contents[commit..]` is mutably borrowed for `'b`. + /// + /// - `message` and `context` are paired together. + /// + /// - `commit` is at most `context.size()`, which is at most + /// `context.max_size()`. + pub unsafe fn from_raw_parts( + message: NonNull, + context: &'b mut BuilderContext, + commit: usize, + ) -> Self { + Self { + message, + _message: PhantomData, + context, + commit, + } + } + + /// Initialize an empty [`Builder`]. + /// + /// # Panics + /// + /// Panics if the buffer is less than 12 bytes long (which is the minimum + /// possible size for a DNS message). + pub fn new( + buffer: &'b mut [u8], + context: &'b mut BuilderContext, + ) -> Self { + assert!(buffer.len() >= 12); + let message = Message::mut_from_bytes(buffer) + .map_err(SizeError::from) + .expect("A 'Message' can fit in 12 bytes"); + context.size = 0; + context.max_size = message.contents.len(); + + // SAFETY: 'message' and 'context' are now consistent. + unsafe { Self::from_raw_parts(message.into(), context, 0) } + } +} + +//--- Inspection + +impl<'b> Builder<'b> { + /// The message header. + /// + /// The header can be modified by the builder, and so is only available + /// for a short lifetime. Note that it implements [`Copy`]. + pub fn header(&self) -> &Header { + // SAFETY: 'message.header' is mutably borrowed by 'self'. + unsafe { &(*self.message.as_ptr()).header } + } + + /// Mutable access to the message header. + pub fn header_mut(&mut self) -> &mut Header { + // SAFETY: 'message.header' is mutably borrowed by 'self'. + unsafe { &mut (*self.message.as_ptr()).header } + } + + /// Committed message contents. + /// + /// The message contents are available for the lifetime `'b`; the builder + /// cannot be used to modify them since they have been committed. + pub fn committed(&self) -> &'b [u8] { + // SAFETY: 'message.contents[..commit]' is immutably borrowed by + // 'self'. + unsafe { &(*self.message.as_ptr()).contents[..self.commit] } + } + + /// The appended but uncommitted contents of the message. + /// + /// The builder can modify or rewind these contents, so they are offered + /// with a short lifetime. + pub fn appended(&self) -> &[u8] { + // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. + let range = self.commit..self.context.size; + unsafe { &(*self.message.as_ptr()).contents[range] } + } + + /// The appended but uncommitted contents of the message, mutably. + pub fn appended_mut(&mut self) -> &mut [u8] { + // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. + let range = self.commit..self.context.size; + unsafe { &mut (*self.message.as_ptr()).contents[range] } + } + + /// Uninitialized space in the message buffer. + /// + /// This can be filled manually, then marked as initialized using + /// [`Self::mark_appended()`]. + pub fn uninitialized(&mut self) -> &mut [u8] { + // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. + let range = self.context.size..self.context.max_size; + unsafe { &mut (*self.message.as_ptr()).contents[range] } + } + + /// The message with all committed contents. + /// + /// The header of the message can be modified by the builder, so the + /// returned reference has a short lifetime. The message contents can be + /// borrowed for a longer lifetime -- see [`Self::committed()`]. + pub fn message(&self) -> &Message { + // SAFETY: All of 'message' can be immutably borrowed by 'self'. + let message = unsafe { &*self.message.as_ptr() }; + let message = message.as_bytes(); + Message::ref_from_bytes_with_elems(message, self.commit) + .map_err(SizeError::from) + .expect("'message' represents a valid 'Message'") + } + + /// The message including any uncommitted contents. + /// + /// The header of the message can be modified by the builder, so the + /// returned reference has a short lifetime. The message contents can be + /// borrowed for a longer lifetime -- see [`Self::committed()`]. + pub fn cur_message(&self) -> &Message { + // SAFETY: All of 'message' can be immutably borrowed by 'self'. + let message = unsafe { &*self.message.as_ptr() }; + let message = message.as_bytes(); + Message::ref_from_bytes_with_elems(message, self.context.size) + .map_err(SizeError::from) + .expect("'message' represents a valid 'Message'") + } + + /// The builder context. + pub fn context(&self) -> &BuilderContext { + &*self.context + } + + /// Decompose this builder into raw parts. + /// + /// This returns three components: + /// + /// - The message buffer. The committed contents of the message (the + /// first `commit` bytes of the message contents) are borrowed immutably + /// for the lifetime `'b`. The remainder of the message buffer is + /// borrowed mutably for the lifetime `'b`. + /// + /// - Context for this builder. + /// + /// - The amount of data committed in the message (`commit`). + /// + /// The builder can be recomposed with [`Self::from_raw_parts()`]. + pub fn into_raw_parts( + self, + ) -> (NonNull, &'b mut BuilderContext, usize) { + // NOTE: The context has to be moved out carefully. + let (message, commit) = (self.message, self.commit); + let this = ManuallyDrop::new(self); + let this = (&*this) as *const Self; + // SAFETY: 'this' is a valid object that can be moved out of. + let context = unsafe { ptr::read(ptr::addr_of!((*this).context)) }; + (message, context, commit) + } +} + +//--- Interaction + +impl<'b> Builder<'b> { + /// Rewind the builder, removing all committed content. + pub fn rewind(&mut self) { + self.context.size = self.commit; + } + + /// Commit all appended content. + pub fn commit(&mut self) { + self.commit = self.context.size; + } + + /// Mark bytes in the buffer as initialized. + /// + /// The given number of bytes from the beginning of + /// [`Self::uninitialized()`] will be marked as initialized, and will be + /// treated as appended content in the buffer. + /// + /// # Panics + /// + /// Panics if the uninitialized buffer is smaller than the given number of + /// initialized bytes. + pub fn mark_appended(&mut self, amount: usize) { + assert!(self.context.max_size - self.context.size >= amount); + self.context.size += amount; + } + + /// Delegate to a new builder. + /// + /// Any content committed by the builder will be added as uncommitted + /// content for this builder. + pub fn delegate(&mut self) -> Builder<'_> { + let commit = self.context.size; + unsafe { + Builder::from_raw_parts(self.message, &mut *self.context, commit) + } + } + + /// Limit the total message size. + /// + /// The message will not be allowed to exceed the given size, in bytes. + /// Only the message header and contents are counted; the enclosing UDP + /// or TCP packet size is not considered. If the message already exceeds + /// this size, a [`TruncationError`] is returned. + /// + /// This size will apply to all + pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { + if self.context.size <= size { + self.context.max_size = size; + Ok(()) + } else { + Err(TruncationError) + } + } + + /// Append data of a known size using a closure. + /// + /// All the requested bytes must be initialized. If not enough free space + /// could be obtained, a [`TruncationError`] is returned. + pub fn append_with( + &mut self, + size: usize, + fill: impl FnOnce(&mut [u8]), + ) -> Result<(), TruncationError> { + self.uninitialized() + .get_mut(..size) + .ok_or(TruncationError) + .map(fill) + } + + /// Append some bytes. + /// + /// No name compression will be performed. + pub fn append_bytes( + &mut self, + bytes: &[u8], + ) -> Result<(), TruncationError> { + self.append_with(bytes.len(), |buffer| buffer.copy_from_slice(bytes)) + } + + /// Compress and append a domain name. + pub fn append_name( + &mut self, + name: &RevName, + ) -> Result<(), TruncationError> { + // TODO: Perform name compression. + self.append_with(name.len(), |mut buffer| { + // Write out the labels in the name in reverse. + for label in name.labels() { + let label_buffer; + let offset = buffer.len() - label.len() - 1; + (buffer, label_buffer) = buffer.split_at_mut(offset); + label_buffer[0] = label.len() as u8; + label_buffer[1..].copy_from_slice(label.as_bytes()); + } + }) + } +} + +//--- Drop + +impl Drop for Builder<'_> { + fn drop(&mut self) { + // Drop uncommitted content. + self.rewind(); + } +} + +//----------- BuilderContext ------------------------------------------------- + +/// Context for building a DNS message. +#[derive(Clone, Debug)] +pub struct BuilderContext { + // TODO: Name compression. + /// The current size of the message contents. + size: usize, + + /// The maximum size of the message contents. + max_size: usize, +} + +//--- Inspection + +impl BuilderContext { + /// The size of the message contents. + pub fn size(&self) -> usize { + self.size + } + + /// The maximum size of the message contents. + pub fn max_size(&self) -> usize { + self.max_size + } +} + +//--- Default + +impl Default for BuilderContext { + fn default() -> Self { + Self { + size: 0, + max_size: 65535 - core::mem::size_of::
(), + } + } +} diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index e0a1c4925..4aec6820d 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -2,6 +2,33 @@ use core::fmt; +mod builder; +pub use builder::{Builder, BuilderContext}; + +//----------- Message-aware building traits ---------------------------------- + +/// Building into a DNS message. +pub trait BuildIntoMessage { + // Append this value to the DNS message. + /// + /// If the byte string is long enough to fit the message, it is appended + /// using the given message builder and committed. Otherwise, a + /// [`TruncationError`] is returned. + fn build_into_message( + &self, + builder: Builder<'_>, + ) -> Result<(), TruncationError>; +} + +impl BuildIntoMessage for &T { + fn build_into_message( + &self, + builder: Builder<'_>, + ) -> Result<(), TruncationError> { + (**self).build_into_message(builder) + } +} + //----------- Low-level building traits -------------------------------------- /// Building into a byte string. From 993eda7bf2941b870dc9c8bcc017c1b459e8f531 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 14:44:17 +0100 Subject: [PATCH 029/167] [new_base/name/reversed] Impl building traits --- src/new_base/build/builder.rs | 16 +++------ src/new_base/name/reversed.rs | 62 +++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 11 deletions(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 7488f8f6a..93fda594b 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -10,7 +10,7 @@ use zerocopy::{FromBytes, IntoBytes, SizeError}; use crate::new_base::{name::RevName, Header, Message}; -use super::TruncationError; +use super::{BuildInto, TruncationError}; //----------- Builder -------------------------------------------------------- @@ -274,6 +274,7 @@ impl<'b> Builder<'b> { .get_mut(..size) .ok_or(TruncationError) .map(fill) + .map(|()| self.context.size += size) } /// Append some bytes. @@ -292,16 +293,9 @@ impl<'b> Builder<'b> { name: &RevName, ) -> Result<(), TruncationError> { // TODO: Perform name compression. - self.append_with(name.len(), |mut buffer| { - // Write out the labels in the name in reverse. - for label in name.labels() { - let label_buffer; - let offset = buffer.len() - label.len() - 1; - (buffer, label_buffer) = buffer.split_at_mut(offset); - label_buffer[0] = label.len() as u8; - label_buffer[1..].copy_from_slice(label.as_bytes()); - } - }) + name.build_into(self.uninitialized())?; + self.mark_appended(name.len()); + Ok(()) } } diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 60a640f91..513a72582 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -12,6 +12,7 @@ use zerocopy::IntoBytes; use zerocopy_derive::*; use crate::new_base::{ + build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, }, @@ -97,6 +98,45 @@ impl RevName { } } +//--- Building into DNS messages + +impl BuildIntoMessage for RevName { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + builder.append_name(self)?; + builder.commit(); + Ok(()) + } +} + +//--- Building into byte strings + +impl BuildInto for RevName { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + if bytes.len() < self.len() { + return Err(TruncationError); + } + + let (mut buffer, rest) = bytes.split_at_mut(self.len()); + + // Write out the labels in the name in reverse. + for label in self.labels() { + let label_buffer; + let offset = buffer.len() - label.len() - 1; + (buffer, label_buffer) = buffer.split_at_mut(offset); + label_buffer[0] = label.len() as u8; + label_buffer[1..].copy_from_slice(label.as_bytes()); + } + + Ok(rest) + } +} + //--- Equality impl PartialEq for RevName { @@ -332,6 +372,17 @@ fn parse_segment<'a>( } } +//--- Building into DNS messages + +impl BuildIntoMessage for RevNameBuf { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + (**self).build_into_message(builder) + } +} + //--- Parsing from bytes impl<'a> SplitFrom<'a> for RevNameBuf { @@ -369,6 +420,17 @@ impl<'a> ParseFrom<'a> for RevNameBuf { } } +//--- Building into byte strings + +impl BuildInto for RevNameBuf { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + (**self).build_into(bytes) + } +} + //--- Interaction impl RevNameBuf { From d40749234bf3ea75a18cdffaf2d0132c313c6f4b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 15:53:54 +0100 Subject: [PATCH 030/167] [new_base/build] Add convenience impls for '[u8]' --- src/new_base/build/mod.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 4aec6820d..108cc76f0 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -29,6 +29,17 @@ impl BuildIntoMessage for &T { } } +impl BuildIntoMessage for [u8] { + fn build_into_message( + &self, + mut builder: Builder<'_>, + ) -> Result<(), TruncationError> { + builder.append_bytes(self)?; + builder.commit(); + Ok(()) + } +} + //----------- Low-level building traits -------------------------------------- /// Building into a byte string. @@ -53,6 +64,21 @@ impl BuildInto for &T { } } +impl BuildInto for [u8] { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + if self.len() <= bytes.len() { + let (bytes, rest) = bytes.split_at_mut(self.len()); + bytes.copy_from_slice(self); + Ok(rest) + } else { + Err(TruncationError) + } + } +} + //----------- TruncationError ------------------------------------------------ /// A DNS message did not fit in a buffer. From 55c7854e874ae82a36bbb67ace8ad1ab8dce823c Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 15:54:14 +0100 Subject: [PATCH 031/167] [new_base/question] Impl building traits --- src/new_base/question.rs | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 121eedff4..f173a664f 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -2,10 +2,11 @@ use core::ops::Range; -use zerocopy::network_endian::U16; +use zerocopy::{network_endian::U16, IntoBytes}; use zerocopy_derive::*; use super::{ + build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, @@ -77,6 +78,23 @@ where } } +//--- Building into DNS messages + +impl BuildIntoMessage for Question +where + N: BuildIntoMessage, +{ + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.qname.build_into_message(builder.delegate())?; + builder.append_bytes(self.qtype.as_bytes())?; + builder.append_bytes(self.qclass.as_bytes())?; + Ok(()) + } +} + //--- Parsing from bytes impl<'a, N> SplitFrom<'a> for Question @@ -103,6 +121,23 @@ where } } +//--- Building into byte strings + +impl BuildInto for Question +where + N: BuildInto, +{ + fn build_into<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.qname.build_into(bytes)?; + bytes = self.qtype.as_bytes().build_into(bytes)?; + bytes = self.qclass.as_bytes().build_into(bytes)?; + Ok(bytes) + } +} + //----------- QType ---------------------------------------------------------- /// The type of a question. From 74e67836698385f4c7c994d5eea6f539b42cf743 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 15:54:24 +0100 Subject: [PATCH 032/167] [new_base/record] Support building and overhaul parsing --- src/new_base/record.rs | 179 ++++++++++++++++++++++++++++++++++------- 1 file changed, 149 insertions(+), 30 deletions(-) diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 42336dc6f..79687cece 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -7,11 +7,12 @@ use core::{ use zerocopy::{ network_endian::{U16, U32}, - FromBytes, IntoBytes, + FromBytes, IntoBytes, SizeError, }; use zerocopy_derive::*; use super::{ + build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, @@ -64,19 +65,104 @@ impl Record { } } +//--- Parsing from DNS messages + +impl<'a, N, D> SplitFromMessage<'a> for Record +where + N: SplitFromMessage<'a>, + D: ParseFromMessage<'a>, +{ + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let (rname, rest) = N::split_from_message(message, start)?; + let (&rtype, rest) = <&RType>::split_from_message(message, rest)?; + let (&rclass, rest) = <&RClass>::split_from_message(message, rest)?; + let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; + let (&size, rest) = <&U16>::split_from_message(message, rest)?; + let size: usize = size.get().into(); + let rdata = if message.as_bytes().len() - rest >= size { + D::parse_from_message(message, rest..rest + size)? + } else { + return Err(ParseError); + }; + + Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest + size)) + } +} + +impl<'a, N, D> ParseFromMessage<'a> for Record +where + N: SplitFromMessage<'a>, + D: ParseFromMessage<'a>, +{ + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + let message = &message.as_bytes()[..range.end]; + let message = Message::ref_from_bytes(message) + .map_err(SizeError::from) + .expect("The input range ends past the message header"); + + let (this, rest) = Self::split_from_message(message, range.start)?; + + if rest == range.end { + Ok(this) + } else { + Err(ParseError) + } + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Record +where + N: BuildIntoMessage, + D: BuildIntoMessage, +{ + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.rname.build_into_message(builder.delegate())?; + builder.append_bytes(self.rtype.as_bytes())?; + builder.append_bytes(self.rclass.as_bytes())?; + builder.append_bytes(self.ttl.as_bytes())?; + + // The offset of the record data size. + let offset = builder.appended().len(); + builder.append_bytes(&0u16.to_be_bytes())?; + self.rdata.build_into_message(builder.delegate())?; + let size = builder.appended().len() - 2 - offset; + let size = + u16::try_from(size).expect("the record data never exceeds 64KiB"); + builder.appended_mut()[offset..offset + 2] + .copy_from_slice(&size.to_be_bytes()); + + builder.commit(); + Ok(()) + } +} + //--- Parsing from bytes impl<'a, N, D> SplitFrom<'a> for Record where N: SplitFrom<'a>, - D: SplitFrom<'a>, + D: ParseFrom<'a>, { fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (rname, rest) = N::split_from(bytes)?; let (rtype, rest) = RType::read_from_prefix(rest)?; let (rclass, rest) = RClass::read_from_prefix(rest)?; let (ttl, rest) = TTL::read_from_prefix(rest)?; - let (rdata, rest) = D::split_from(rest)?; + let (size, rest) = U16::read_from_prefix(rest)?; + let size: usize = size.get().into(); + let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; + let rdata = D::parse_from(rdata)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } @@ -92,12 +178,44 @@ where let (rtype, rest) = RType::read_from_prefix(rest)?; let (rclass, rest) = RClass::read_from_prefix(rest)?; let (ttl, rest) = TTL::read_from_prefix(rest)?; - let rdata = D::parse_from(rest)?; + let (size, rest) = U16::read_from_prefix(rest)?; + let size: usize = size.get().into(); + let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; + let rdata = D::parse_from(rdata)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } } +//--- Building into byte strings + +impl BuildInto for Record +where + N: BuildInto, + D: BuildInto, +{ + fn build_into<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.rname.build_into(bytes)?; + bytes = self.rtype.as_bytes().build_into(bytes)?; + bytes = self.rclass.as_bytes().build_into(bytes)?; + bytes = self.ttl.as_bytes().build_into(bytes)?; + + let (size, bytes) = + ::mut_from_prefix(bytes).map_err(|_| TruncationError)?; + let bytes_len = bytes.len(); + + let rest = self.rdata.build_into(bytes)?; + *size = u16::try_from(bytes_len - rest.len()) + .expect("the record data never exceeds 64KiB") + .into(); + + Ok(rest) + } +} + //----------- RType ---------------------------------------------------------- /// The type of a record. @@ -194,18 +312,6 @@ impl UnparsedRecordData { //--- Parsing from DNS messages -impl<'a> SplitFromMessage<'a> for &'a UnparsedRecordData { - fn split_from_message( - message: &'a Message, - start: usize, - ) -> Result<(Self, usize), ParseError> { - let message = message.as_bytes(); - let bytes = message.get(start..).ok_or(ParseError)?; - let (this, rest) = Self::split_from(bytes)?; - Ok((this, message.len() - rest.len())) - } -} - impl<'a> ParseFromMessage<'a> for &'a UnparsedRecordData { fn parse_from_message( message: &'a Message, @@ -217,26 +323,39 @@ impl<'a> ParseFromMessage<'a> for &'a UnparsedRecordData { } } -//--- Parsing from bytes +//--- Building into DNS messages -impl<'a> SplitFrom<'a> for &'a UnparsedRecordData { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (size, rest) = U16::read_from_prefix(bytes)?; - let size = size.get() as usize; - let (data, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; - // SAFETY: 'data.len() == size' which is a 'u16'. - let this = unsafe { UnparsedRecordData::new_unchecked(data) }; - Ok((this, rest)) +impl BuildIntoMessage for UnparsedRecordData { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.0.build_into_message(builder) } } +//--- Parsing from bytes + impl<'a> ParseFrom<'a> for &'a UnparsedRecordData { fn parse_from(bytes: &'a [u8]) -> Result { - let (size, rest) = U16::read_from_prefix(bytes)?; - let size = size.get() as usize; - let data = <[u8]>::ref_from_bytes_with_elems(rest, size)?; - // SAFETY: 'data.len() == size' which is a 'u16'. - Ok(unsafe { UnparsedRecordData::new_unchecked(data) }) + if bytes.len() > 65535 { + // Too big to fit in an 'UnparsedRecordData'. + return Err(ParseError); + } + + // SAFETY: 'bytes.len()' fits within a 'u16'. + Ok(unsafe { UnparsedRecordData::new_unchecked(bytes) }) + } +} + +//--- Building into byte strings + +impl BuildInto for UnparsedRecordData { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.0.build_into(bytes) } } From 6772427ab1b31b9d7d05e5817aac47aba7ac5621 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 16:17:58 +0100 Subject: [PATCH 033/167] [new_base/charstr] Support building --- src/new_base/charstr.rs | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index dfe2bc8f9..fdd5e5bdf 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -6,6 +6,7 @@ use zerocopy::IntoBytes; use zerocopy_derive::*; use super::{ + build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, }, @@ -48,6 +49,20 @@ impl<'a> ParseFromMessage<'a> for &'a CharStr { } } +//--- Building into DNS messages + +impl BuildIntoMessage for CharStr { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + builder.append_bytes(&[self.octets.len() as u8])?; + builder.append_bytes(&self.octets)?; + builder.commit(); + Ok(()) + } +} + //--- Parsing from bytes impl<'a> SplitFrom<'a> for &'a CharStr { @@ -75,6 +90,20 @@ impl<'a> ParseFrom<'a> for &'a CharStr { } } +//--- Building into byte strings + +impl BuildInto for CharStr { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + let (length, bytes) = + bytes.split_first_mut().ok_or(TruncationError)?; + *length = self.octets.len() as u8; + self.octets.build_into(bytes) + } +} + //--- Equality impl PartialEq for CharStr { From 59c33b2c5b2409ff7642de74ac9810e2a185ca41 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 16:18:08 +0100 Subject: [PATCH 034/167] [new_rdata/rfc1035] Impl building traits --- src/new_rdata/rfc1035.rs | 224 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 224 insertions(+) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 4e09c9727..dc25f0007 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -9,6 +9,7 @@ use zerocopy::{ use zerocopy_derive::*; use crate::new_base::{ + build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, }, @@ -73,6 +74,28 @@ impl fmt::Display for A { } } +//--- Building into DNS messages + +impl BuildIntoMessage for A { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.as_bytes().build_into_message(builder) + } +} + +//--- Building into byte strings + +impl BuildInto for A { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_into(bytes) + } +} + //----------- Ns ------------------------------------------------------------- /// The authoritative name server for this domain. @@ -94,6 +117,17 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Ns { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.name.build_into_message(builder) + } +} + //--- Parsing from bytes impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ns { @@ -102,6 +136,17 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ns { } } +//--- Building into bytes + +impl BuildInto for Ns { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.name.build_into(bytes) + } +} + //----------- Cname ---------------------------------------------------------- /// The canonical name for this domain. @@ -123,6 +168,17 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Cname { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Cname { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.name.build_into_message(builder) + } +} + //--- Parsing from bytes impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Cname { @@ -131,6 +187,17 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Cname { } } +//--- Building into bytes + +impl BuildInto for Cname { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.name.build_into(bytes) + } +} + //----------- Soa ------------------------------------------------------------ /// The start of a zone of authority. @@ -185,6 +252,25 @@ impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Soa { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.mname.build_into_message(builder.delegate())?; + self.rname.build_into_message(builder.delegate())?; + builder.append_bytes(self.serial.as_bytes())?; + builder.append_bytes(self.refresh.as_bytes())?; + builder.append_bytes(self.retry.as_bytes())?; + builder.append_bytes(self.expire.as_bytes())?; + builder.append_bytes(self.minimum.as_bytes())?; + builder.commit(); + Ok(()) + } +} + //--- Parsing from bytes impl<'a, N: SplitFrom<'a>> ParseFrom<'a> for Soa { @@ -209,6 +295,24 @@ impl<'a, N: SplitFrom<'a>> ParseFrom<'a> for Soa { } } +//--- Building into byte strings + +impl BuildInto for Soa { + fn build_into<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.mname.build_into(bytes)?; + bytes = self.rname.build_into(bytes)?; + bytes = self.serial.as_bytes().build_into(bytes)?; + bytes = self.refresh.as_bytes().build_into(bytes)?; + bytes = self.retry.as_bytes().build_into(bytes)?; + bytes = self.expire.as_bytes().build_into(bytes)?; + bytes = self.minimum.as_bytes().build_into(bytes)?; + Ok(bytes) + } +} + //----------- Wks ------------------------------------------------------------ /// Well-known services supported on this domain. @@ -253,6 +357,28 @@ impl fmt::Debug for Wks { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Wks { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.as_bytes().build_into_message(builder) + } +} + +//--- Building into byte strings + +impl BuildInto for Wks { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_into(bytes) + } +} + //----------- Ptr ------------------------------------------------------------ /// A pointer to another domain name. @@ -274,6 +400,17 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Ptr { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.name.build_into_message(builder) + } +} + //--- Parsing from bytes impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ptr { @@ -282,6 +419,17 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ptr { } } +//--- Building into bytes + +impl BuildInto for Ptr { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.name.build_into(bytes) + } +} + //----------- Hinfo ---------------------------------------------------------- /// Information about the host computer. @@ -309,6 +457,20 @@ impl<'a> ParseFromMessage<'a> for Hinfo<'a> { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Hinfo<'_> { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.cpu.build_into_message(builder.delegate())?; + self.os.build_into_message(builder.delegate())?; + builder.commit(); + Ok(()) + } +} + //--- Parsing from bytes impl<'a> ParseFrom<'a> for Hinfo<'a> { @@ -319,6 +481,19 @@ impl<'a> ParseFrom<'a> for Hinfo<'a> { } } +//--- Building into bytes + +impl BuildInto for Hinfo<'_> { + fn build_into<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.cpu.build_into(bytes)?; + bytes = self.os.build_into(bytes)?; + Ok(bytes) + } +} + //----------- Mx ------------------------------------------------------------- /// A host that can exchange mail for this domain. @@ -349,6 +524,20 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Mx { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + builder.append_bytes(self.preference.as_bytes())?; + self.exchange.build_into_message(builder.delegate())?; + builder.commit(); + Ok(()) + } +} + //--- Parsing from bytes impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Mx { @@ -362,6 +551,19 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Mx { } } +//--- Building into byte strings + +impl BuildInto for Mx { + fn build_into<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.preference.as_bytes().build_into(bytes)?; + bytes = self.exchange.build_into(bytes)?; + Ok(bytes) + } +} + //----------- Txt ------------------------------------------------------------ /// Free-form text strings about this domain. @@ -389,6 +591,17 @@ impl<'a> ParseFromMessage<'a> for &'a Txt { } } +//--- Building into DNS messages + +impl BuildIntoMessage for Txt { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.content.build_into_message(builder) + } +} + //--- Parsing from bytes impl<'a> ParseFrom<'a> for &'a Txt { @@ -403,3 +616,14 @@ impl<'a> ParseFrom<'a> for &'a Txt { Ok(unsafe { core::mem::transmute::<&'a [u8], Self>(bytes) }) } } + +//--- Building into byte strings + +impl BuildInto for Txt { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.content.build_into(bytes) + } +} From 2d845732dd74d34f76f02462cd3f7b9f2c61962a Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 26 Dec 2024 16:57:13 +0100 Subject: [PATCH 035/167] [build/builder] Improve unclear documentation --- src/new_base/build/builder.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 93fda594b..da6db0ba4 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -70,6 +70,9 @@ impl<'b> Builder<'b> { /// Initialize an empty [`Builder`]. /// + /// The message header is left uninitialized. Use [`Self::header_mut()`] + /// to initialize it. + /// /// # Panics /// /// Panics if the buffer is less than 12 bytes long (which is the minimum @@ -251,10 +254,17 @@ impl<'b> Builder<'b> { /// or TCP packet size is not considered. If the message already exceeds /// this size, a [`TruncationError`] is returned. /// - /// This size will apply to all + /// This size will apply to all builders for this message (including those + /// that delegated to `self`). It will not be automatically revoked if + /// message building fails. + /// + /// # Panics + /// + /// Panics if the given size is less than 12 bytes. pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { - if self.context.size <= size { - self.context.max_size = size; + assert!(size >= 12); + if self.context.size <= size - 12 { + self.context.max_size = size - 12; Ok(()) } else { Err(TruncationError) From a8433c91bc57ec5fff793ed8ee996cd8b778a802 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 27 Dec 2024 12:46:32 +0100 Subject: [PATCH 036/167] [new_rdata/rfc1035] gate 'Ipv4Addr' behind 'std' --- src/new_rdata/rfc1035.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index dc25f0007..6bc8eead2 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -1,6 +1,9 @@ //! Core record data types. -use core::{fmt, net::Ipv4Addr, ops::Range, str::FromStr}; +use core::{fmt, ops::Range, str::FromStr}; + +#[cfg(feature = "std")] +use std::net::Ipv4Addr; use zerocopy::{ network_endian::{U16, U32}, @@ -42,6 +45,7 @@ pub struct A { //--- Converting to and from 'Ipv4Addr' +#[cfg(feature = "std")] impl From for A { fn from(value: Ipv4Addr) -> Self { Self { @@ -50,6 +54,7 @@ impl From for A { } } +#[cfg(feature = "std")] impl From for Ipv4Addr { fn from(value: A) -> Self { Self::from(value.octets) @@ -58,6 +63,7 @@ impl From for Ipv4Addr { //--- Parsing from a string +#[cfg(feature = "std")] impl FromStr for A { type Err = ::Err; @@ -68,6 +74,7 @@ impl FromStr for A { //--- Formatting +#[cfg(feature = "std")] impl fmt::Display for A { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { Ipv4Addr::from(*self).fmt(f) From 86f14bb7a81477baf954e98a0f5324be572ff130 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 27 Dec 2024 12:50:53 +0100 Subject: [PATCH 037/167] [new_rdata/rfc1035] Gate more things under 'std' --- src/new_rdata/rfc1035.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index 6bc8eead2..a05d3cb97 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -1,6 +1,8 @@ //! Core record data types. -use core::{fmt, ops::Range, str::FromStr}; +use core::ops::Range; +#[cfg(feature = "std")] +use core::{fmt, str::FromStr}; #[cfg(feature = "std")] use std::net::Ipv4Addr; @@ -338,6 +340,7 @@ pub struct Wks { //--- Formatting +#[cfg(feature = "std")] impl fmt::Debug for Wks { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { struct Ports<'a>(&'a [u8]); From 072827597b2dba6edd85d4a3478e128d6f9da4ac Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 27 Dec 2024 12:55:10 +0100 Subject: [PATCH 038/167] [new_base/build/builder] Remove unnecessary explicit lifetime in impl --- src/new_base/build/builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index da6db0ba4..75a9cfc69 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -210,7 +210,7 @@ impl<'b> Builder<'b> { //--- Interaction -impl<'b> Builder<'b> { +impl Builder<'_> { /// Rewind the builder, removing all committed content. pub fn rewind(&mut self) { self.context.size = self.commit; From f73ca63e1be6832277ee39cbb3c8539a24e316c9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Sat, 28 Dec 2024 09:51:16 +0100 Subject: [PATCH 039/167] [new_rdata/rfc1035] Support 'Display' outside 'std' --- src/new_rdata/rfc1035.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/rfc1035.rs index a05d3cb97..f42cb3180 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/rfc1035.rs @@ -1,8 +1,9 @@ //! Core record data types. -use core::ops::Range; +use core::{fmt, ops::Range}; + #[cfg(feature = "std")] -use core::{fmt, str::FromStr}; +use core::str::FromStr; #[cfg(feature = "std")] use std::net::Ipv4Addr; @@ -76,10 +77,10 @@ impl FromStr for A { //--- Formatting -#[cfg(feature = "std")] impl fmt::Display for A { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - Ipv4Addr::from(*self).fmt(f) + let [a, b, c, d] = self.octets; + write!(f, "{a}.{b}.{c}.{d}") } } @@ -340,7 +341,6 @@ pub struct Wks { //--- Formatting -#[cfg(feature = "std")] impl fmt::Debug for Wks { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { struct Ports<'a>(&'a [u8]); @@ -360,7 +360,7 @@ impl fmt::Debug for Wks { } f.debug_struct("Wks") - .field("address", &Ipv4Addr::from(self.address)) + .field("address", &format_args!("{}", self.address)) .field("protocol", &self.protocol) .field("ports", &Ports(&self.ports)) .finish() From 1133b4c0860cef74a095dfa10caaaa6a1e49fcb7 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 14:08:55 +0100 Subject: [PATCH 040/167] [new_rdata] Inline 'rfc1035' and support 'rfc3596' --- src/new_rdata/mod.rs | 6 ++- src/new_rdata/rfc3596.rs | 98 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 1 deletion(-) create mode 100644 src/new_rdata/rfc3596.rs diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 54afb39ee..248c02d37 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -1,3 +1,7 @@ //! Record data types. -pub mod rfc1035; +mod rfc1035; +pub use rfc1035::{Cname, Hinfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; + +mod rfc3596; +pub use rfc3596::Aaaa; diff --git a/src/new_rdata/rfc3596.rs b/src/new_rdata/rfc3596.rs new file mode 100644 index 000000000..9a474aab1 --- /dev/null +++ b/src/new_rdata/rfc3596.rs @@ -0,0 +1,98 @@ +//! IPv6 record data types. + +#[cfg(feature = "std")] +use core::{fmt, str::FromStr}; + +#[cfg(feature = "std")] +use std::net::Ipv6Addr; + +use zerocopy::IntoBytes; +use zerocopy_derive::*; + +use crate::new_base::build::{ + self, BuildInto, BuildIntoMessage, TruncationError, +}; + +//----------- Aaaa ----------------------------------------------------------- + +/// The IPv6 address of a host responsible for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct Aaaa { + /// The IPv6 address octets. + pub octets: [u8; 16], +} + +//--- Converting to and from 'Ipv6Addr' + +#[cfg(feature = "std")] +impl From for Aaaa { + fn from(value: Ipv6Addr) -> Self { + Self { + octets: value.octets(), + } + } +} + +#[cfg(feature = "std")] +impl From for Ipv6Addr { + fn from(value: Aaaa) -> Self { + Self::from(value.octets) + } +} + +//--- Parsing from a string + +#[cfg(feature = "std")] +impl FromStr for Aaaa { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + Ipv6Addr::from_str(s).map(Aaaa::from) + } +} + +//--- Formatting + +#[cfg(feature = "std")] +impl fmt::Display for Aaaa { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Ipv6Addr::from(*self).fmt(f) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Aaaa { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.as_bytes().build_into_message(builder) + } +} + +//--- Building into byte strings + +impl BuildInto for Aaaa { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_into(bytes) + } +} From 1a95ae811cac3264d8ba3e7ece2f814e37da80de Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 14:10:51 +0100 Subject: [PATCH 041/167] [new_rdata] Rename submodules with more intuitive names --- src/new_rdata/{rfc1035.rs => basic.rs} | 2 ++ src/new_rdata/{rfc3596.rs => ipv6.rs} | 2 ++ src/new_rdata/mod.rs | 8 ++++---- 3 files changed, 8 insertions(+), 4 deletions(-) rename src/new_rdata/{rfc1035.rs => basic.rs} (99%) rename src/new_rdata/{rfc3596.rs => ipv6.rs} (96%) diff --git a/src/new_rdata/rfc1035.rs b/src/new_rdata/basic.rs similarity index 99% rename from src/new_rdata/rfc1035.rs rename to src/new_rdata/basic.rs index f42cb3180..2abb38cf4 100644 --- a/src/new_rdata/rfc1035.rs +++ b/src/new_rdata/basic.rs @@ -1,4 +1,6 @@ //! Core record data types. +//! +//! See [RFC 1035](https://datatracker.ietf.org/doc/html/rfc1035). use core::{fmt, ops::Range}; diff --git a/src/new_rdata/rfc3596.rs b/src/new_rdata/ipv6.rs similarity index 96% rename from src/new_rdata/rfc3596.rs rename to src/new_rdata/ipv6.rs index 9a474aab1..606486d08 100644 --- a/src/new_rdata/rfc3596.rs +++ b/src/new_rdata/ipv6.rs @@ -1,4 +1,6 @@ //! IPv6 record data types. +//! +//! See [RFC 3596](https://datatracker.ietf.org/doc/html/rfc3596). #[cfg(feature = "std")] use core::{fmt, str::FromStr}; diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 248c02d37..bbec94d3d 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -1,7 +1,7 @@ //! Record data types. -mod rfc1035; -pub use rfc1035::{Cname, Hinfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; +mod basic; +pub use basic::{Cname, Hinfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; -mod rfc3596; -pub use rfc3596::Aaaa; +mod ipv6; +pub use ipv6::Aaaa; From 6bf26c5dd538ff0f46e13bbfafc5c1a27ab2abee Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 15:20:54 +0100 Subject: [PATCH 042/167] [new_base] Set up basic EDNS support Instead of 'new_base::opt', EDNS is now granted its own top-level module. This matches up well with 'crate::tsig'. --- src/lib.rs | 6 +- src/new_base/message.rs | 2 +- src/new_base/record.rs | 8 +- src/new_edns/mod.rs | 189 ++++++++++++++++++++++++++++++++++++++++ src/new_rdata/edns.rs | 55 ++++++++++++ src/new_rdata/mod.rs | 3 + 6 files changed, 257 insertions(+), 6 deletions(-) create mode 100644 src/new_edns/mod.rs create mode 100644 src/new_rdata/edns.rs diff --git a/src/lib.rs b/src/lib.rs index b2f7ac66c..4d08972f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -193,8 +193,6 @@ extern crate core; pub mod base; pub mod dep; pub mod net; -pub mod new_base; -pub mod new_rdata; pub mod rdata; pub mod resolv; pub mod sign; @@ -205,3 +203,7 @@ pub mod validate; pub mod validator; pub mod zonefile; pub mod zonetree; + +pub mod new_base; +pub mod new_edns; +pub mod new_rdata; diff --git a/src/new_base/message.rs b/src/new_base/message.rs index c07d605fa..e60ae76ff 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -186,7 +186,7 @@ impl fmt::Debug for HeaderFlags { .field("should_recurse (rd)", &self.should_recurse()) .field("can_recurse (ra)", &self.can_recurse()) .field("rcode", &self.rcode()) - .field("bits", &self.inner.get()) + .field("bits", &self.bits()) .finish() } } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 79687cece..9522d80d3 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -317,9 +317,11 @@ impl<'a> ParseFromMessage<'a> for &'a UnparsedRecordData { message: &'a Message, range: Range, ) -> Result { - let message = message.as_bytes(); - let bytes = message.get(range).ok_or(ParseError)?; - Self::parse_from(bytes) + message + .as_bytes() + .get(range) + .ok_or(ParseError) + .and_then(Self::parse_from) } } diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs new file mode 100644 index 000000000..014ac7104 --- /dev/null +++ b/src/new_edns/mod.rs @@ -0,0 +1,189 @@ +//! Support for Extended DNS (RFC 6891). +//! +//! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). + +use core::fmt; + +use zerocopy::{network_endian::U16, FromBytes, IntoBytes}; +use zerocopy_derive::*; + +use crate::{ + new_base::{ + parse::{ParseError, SplitFrom, SplitFromMessage}, + Message, + }, + new_rdata::Opt, +}; + +//----------- EdnsRecord ----------------------------------------------------- + +/// An Extended DNS record. +#[derive(Clone)] +pub struct EdnsRecord<'a> { + /// The largest UDP payload the DNS client supports, in bytes. + pub max_udp_payload: U16, + + /// An extension to the response code of the DNS message. + pub ext_rcode: u8, + + /// The Extended DNS version used by this message. + pub version: u8, + + /// Flags describing the message. + pub flags: EdnsFlags, + + /// Extended DNS options. + pub options: &'a Opt, +} + +//--- Parsing from DNS messages + +impl<'a> SplitFromMessage<'a> for EdnsRecord<'a> { + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let bytes = message.as_bytes().get(start..).ok_or(ParseError)?; + let (this, rest) = Self::split_from(bytes)?; + Ok((this, message.as_bytes().len() - rest.len())) + } +} + +//--- Parsing from bytes + +impl<'a> SplitFrom<'a> for EdnsRecord<'a> { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + // Strip the record name (root) and the record type. + let rest = bytes.strip_prefix(&[0, 0, 41]).ok_or(ParseError)?; + + let (&max_udp_payload, rest) = <&U16>::split_from(rest)?; + let (&ext_rcode, rest) = <&u8>::split_from(rest)?; + let (&version, rest) = <&u8>::split_from(rest)?; + let (&flags, rest) = <&EdnsFlags>::split_from(rest)?; + + // Split the record size and data. + let (&size, rest) = <&U16>::split_from(rest)?; + let size: usize = size.get().into(); + let (options, rest) = Opt::ref_from_prefix_with_elems(rest, size)?; + + Ok(( + Self { + max_udp_payload, + ext_rcode, + version, + flags, + options, + }, + rest, + )) + } +} + +//----------- EdnsFlags ------------------------------------------------------ + +/// Extended DNS flags describing a message. +#[derive( + Copy, + Clone, + Default, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct EdnsFlags { + inner: U16, +} + +//--- Interaction + +impl EdnsFlags { + /// Get the specified flag bit. + fn get_flag(&self, pos: u32) -> bool { + self.inner.get() & (1 << pos) != 0 + } + + /// Set the specified flag bit. + fn set_flag(mut self, pos: u32, value: bool) -> Self { + self.inner &= !(1 << pos); + self.inner |= (value as u16) << pos; + self + } + + /// The raw flags bits. + pub fn bits(&self) -> u16 { + self.inner.get() + } + + /// Whether the client supports DNSSEC. + /// + /// See [RFC 3225](https://datatracker.ietf.org/doc/html/rfc3225). + pub fn is_dnssec_ok(&self) -> bool { + self.get_flag(15) + } + + /// Indicate support for DNSSEC to the server. + /// + /// See [RFC 3225](https://datatracker.ietf.org/doc/html/rfc3225). + pub fn set_dnssec_ok(self, value: bool) -> Self { + self.set_flag(15, value) + } +} + +//--- Formatting + +impl fmt::Debug for EdnsFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("EdnsFlags") + .field("dnssec_ok (do)", &self.is_dnssec_ok()) + .field("bits", &self.bits()) + .finish() + } +} + +//----------- EdnsOption ----------------------------------------------------- + +/// An Extended DNS option. +#[derive(Debug)] +#[non_exhaustive] +pub enum EdnsOption<'b> { + /// An unknown option. + Unknown(OptionCode, &'b UnknownOption), +} + +//----------- OptionCode ----------------------------------------------------- + +/// An Extended DNS option code. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(transparent)] +pub struct OptionCode { + /// The option code. + pub code: U16, +} + +//----------- UnknownOption -------------------------------------------------- + +/// Data for an unknown Extended DNS option. +#[derive(Debug, FromBytes, IntoBytes, Immutable, Unaligned)] +#[repr(transparent)] +pub struct UnknownOption { + /// The unparsed option data. + pub octets: [u8], +} diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs new file mode 100644 index 000000000..89e146062 --- /dev/null +++ b/src/new_rdata/edns.rs @@ -0,0 +1,55 @@ +//! Record data types for Extended DNS. +//! +//! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). + +use zerocopy_derive::*; + +use crate::new_base::build::{ + self, BuildInto, BuildIntoMessage, TruncationError, +}; + +//----------- Opt ------------------------------------------------------------ + +/// Extended DNS options. +#[derive( + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + FromBytes, + IntoBytes, + KnownLayout, + Immutable, + Unaligned, +)] +#[repr(C)] // 'derive(KnownLayout)' doesn't work with 'repr(transparent)'. +pub struct Opt { + /// The raw serialized options. + contents: [u8], +} + +// TODO: Parsing the EDNS options. +// TODO: Formatting. + +//--- Building into DNS messages + +impl BuildIntoMessage for Opt { + fn build_into_message( + &self, + builder: build::Builder<'_>, + ) -> Result<(), TruncationError> { + self.contents.build_into_message(builder) + } +} + +//--- Building into byte strings + +impl BuildInto for Opt { + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.contents.build_into(bytes) + } +} diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index bbec94d3d..1aad4cca0 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -5,3 +5,6 @@ pub use basic::{Cname, Hinfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; mod ipv6; pub use ipv6::Aaaa; + +mod edns; +pub use edns::Opt; From aa0c59036ea53804ade672b6eeb6b8c9a0c9231e Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 15:45:05 +0100 Subject: [PATCH 043/167] [new_base/record] Add trait 'ParseRecordData' --- src/new_base/mod.rs | 3 +- src/new_base/record.rs | 73 +++++++++++++++++++++++++----------------- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 899225cf8..3c2e34068 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -14,7 +14,8 @@ pub use question::{QClass, QType, Question, UnparsedQuestion}; mod record; pub use record::{ - RClass, RType, Record, UnparsedRecord, UnparsedRecordData, TTL, + ParseRecordData, RClass, RType, Record, UnparsedRecord, + UnparsedRecordData, TTL, }; //--- Elements of DNS messages diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 9522d80d3..6f3a1daa8 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -70,7 +70,7 @@ impl Record { impl<'a, N, D> SplitFromMessage<'a> for Record where N: SplitFromMessage<'a>, - D: ParseFromMessage<'a>, + D: ParseRecordData<'a>, { fn split_from_message( message: &'a Message, @@ -83,7 +83,7 @@ where let (&size, rest) = <&U16>::split_from_message(message, rest)?; let size: usize = size.get().into(); let rdata = if message.as_bytes().len() - rest >= size { - D::parse_from_message(message, rest..rest + size)? + D::parse_record_data(message, rest..rest + size, rtype)? } else { return Err(ParseError); }; @@ -95,7 +95,7 @@ where impl<'a, N, D> ParseFromMessage<'a> for Record where N: SplitFromMessage<'a>, - D: ParseFromMessage<'a>, + D: ParseRecordData<'a>, { fn parse_from_message( message: &'a Message, @@ -152,7 +152,7 @@ where impl<'a, N, D> SplitFrom<'a> for Record where N: SplitFrom<'a>, - D: ParseFrom<'a>, + D: ParseRecordData<'a>, { fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (rname, rest) = N::split_from(bytes)?; @@ -162,7 +162,7 @@ where let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; - let rdata = D::parse_from(rdata)?; + let rdata = D::parse_record_data_bytes(rdata, rtype)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } @@ -171,7 +171,7 @@ where impl<'a, N, D> ParseFrom<'a> for Record where N: SplitFrom<'a>, - D: ParseFrom<'a>, + D: ParseRecordData<'a>, { fn parse_from(bytes: &'a [u8]) -> Result { let (rname, rest) = N::split_from(bytes)?; @@ -181,7 +181,7 @@ where let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; - let rdata = D::parse_from(rdata)?; + let rdata = D::parse_record_data_bytes(rdata, rtype)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } @@ -288,6 +288,24 @@ pub struct TTL { pub value: U32, } +//----------- ParseRecordData ------------------------------------------------ + +/// Parsing DNS record data. +pub trait ParseRecordData<'a>: Sized { + /// Parse DNS record data of the given type from a DNS message. + fn parse_record_data( + message: &'a Message, + range: Range, + rtype: RType, + ) -> Result; + + /// Parse DNS record data of the given type from a byte string. + fn parse_record_data_bytes( + bytes: &'a [u8], + rtype: RType, + ) -> Result; +} + //----------- UnparsedRecordData --------------------------------------------- /// Unparsed DNS record data. @@ -310,18 +328,29 @@ impl UnparsedRecordData { } } -//--- Parsing from DNS messages +//--- Parsing record data -impl<'a> ParseFromMessage<'a> for &'a UnparsedRecordData { - fn parse_from_message( +impl<'a> ParseRecordData<'a> for &'a UnparsedRecordData { + fn parse_record_data( message: &'a Message, range: Range, + rtype: RType, + ) -> Result { + let bytes = message.as_bytes().get(range).ok_or(ParseError)?; + Self::parse_record_data_bytes(bytes, rtype) + } + + fn parse_record_data_bytes( + bytes: &'a [u8], + _rtype: RType, ) -> Result { - message - .as_bytes() - .get(range) - .ok_or(ParseError) - .and_then(Self::parse_from) + if bytes.len() > 65535 { + // Too big to fit in an 'UnparsedRecordData'. + return Err(ParseError); + } + + // SAFETY: 'bytes.len()' fits within a 'u16'. + Ok(unsafe { UnparsedRecordData::new_unchecked(bytes) }) } } @@ -336,20 +365,6 @@ impl BuildIntoMessage for UnparsedRecordData { } } -//--- Parsing from bytes - -impl<'a> ParseFrom<'a> for &'a UnparsedRecordData { - fn parse_from(bytes: &'a [u8]) -> Result { - if bytes.len() > 65535 { - // Too big to fit in an 'UnparsedRecordData'. - return Err(ParseError); - } - - // SAFETY: 'bytes.len()' fits within a 'u16'. - Ok(unsafe { UnparsedRecordData::new_unchecked(bytes) }) - } -} - //--- Building into byte strings impl BuildInto for UnparsedRecordData { From d11f9e577d41fe1bf70a993337b46f6f8a16a155 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 23:50:22 +0100 Subject: [PATCH 044/167] [new_base/parse] Make 'Split*' imply 'Parse*' --- src/new_base/name/label.rs | 10 +++++++- src/new_base/parse/mod.rs | 4 ++-- src/new_edns/mod.rs | 49 ++++++++++++++++++++++++++++++++++---- 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index b93b32f80..087692047 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -9,7 +9,7 @@ use core::{ use zerocopy_derive::*; -use crate::new_base::parse::{ParseError, SplitFrom}; +use crate::new_base::parse::{ParseError, ParseFrom, SplitFrom}; //----------- Label ---------------------------------------------------------- @@ -65,6 +65,14 @@ impl<'a> SplitFrom<'a> for &'a Label { } } +impl<'a> ParseFrom<'a> for &'a Label { + fn parse_from(bytes: &'a [u8]) -> Result { + Self::split_from(bytes).and_then(|(this, rest)| { + rest.is_empty().then_some(this).ok_or(ParseError) + }) + } +} + //--- Inspection impl Label { diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 022ff9df2..31ad191ba 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -18,7 +18,7 @@ use super::Message; //----------- Message-aware parsing traits ----------------------------------- /// A type that can be parsed from a DNS message. -pub trait SplitFromMessage<'a>: Sized { +pub trait SplitFromMessage<'a>: Sized + ParseFromMessage<'a> { /// Parse a value of [`Self`] from the start of a byte string within a /// particular DNS message. /// @@ -80,7 +80,7 @@ where //----------- Low-level parsing traits --------------------------------------- /// Parsing from the start of a byte string. -pub trait SplitFrom<'a>: Sized { +pub trait SplitFrom<'a>: Sized + ParseFrom<'a> { /// Parse a value of [`Self`] from the start of the byte string. /// /// If parsing is successful, the parsed value and the rest of the string diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 014ac7104..f4d12f9c1 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -2,14 +2,17 @@ //! //! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). -use core::fmt; +use core::{fmt, ops::Range}; use zerocopy::{network_endian::U16, FromBytes, IntoBytes}; use zerocopy_derive::*; use crate::{ new_base::{ - parse::{ParseError, SplitFrom, SplitFromMessage}, + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, + SplitFromMessage, + }, Message, }, new_rdata::Opt, @@ -49,6 +52,19 @@ impl<'a> SplitFromMessage<'a> for EdnsRecord<'a> { } } +impl<'a> ParseFromMessage<'a> for EdnsRecord<'a> { + fn parse_from_message( + message: &'a Message, + range: Range, + ) -> Result { + message + .as_bytes() + .get(range) + .ok_or(ParseError) + .and_then(Self::parse_from) + } +} + //--- Parsing from bytes impl<'a> SplitFrom<'a> for EdnsRecord<'a> { @@ -79,6 +95,31 @@ impl<'a> SplitFrom<'a> for EdnsRecord<'a> { } } +impl<'a> ParseFrom<'a> for EdnsRecord<'a> { + fn parse_from(bytes: &'a [u8]) -> Result { + // Strip the record name (root) and the record type. + let rest = bytes.strip_prefix(&[0, 0, 41]).ok_or(ParseError)?; + + let (&max_udp_payload, rest) = <&U16>::split_from(rest)?; + let (&ext_rcode, rest) = <&u8>::split_from(rest)?; + let (&version, rest) = <&u8>::split_from(rest)?; + let (&flags, rest) = <&EdnsFlags>::split_from(rest)?; + + // Split the record size and data. + let (&size, rest) = <&U16>::split_from(rest)?; + let size: usize = size.get().into(); + let options = Opt::ref_from_bytes_with_elems(rest, size)?; + + Ok(Self { + max_udp_payload, + ext_rcode, + version, + flags, + options, + }) + } +} + //----------- EdnsFlags ------------------------------------------------------ /// Extended DNS flags describing a message. @@ -181,8 +222,8 @@ pub struct OptionCode { //----------- UnknownOption -------------------------------------------------- /// Data for an unknown Extended DNS option. -#[derive(Debug, FromBytes, IntoBytes, Immutable, Unaligned)] -#[repr(transparent)] +#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C)] pub struct UnknownOption { /// The unparsed option data. pub octets: [u8], From ff7d9136b69b8bca98d26366d0ce0f2b86bf1ff7 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 23:51:57 +0100 Subject: [PATCH 045/167] [new_base/record] Add a default for 'parse_record_data()' --- src/new_base/record.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 6f3a1daa8..93038f8d9 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -297,7 +297,10 @@ pub trait ParseRecordData<'a>: Sized { message: &'a Message, range: Range, rtype: RType, - ) -> Result; + ) -> Result { + let bytes = message.as_bytes().get(range).ok_or(ParseError)?; + Self::parse_record_data_bytes(bytes, rtype) + } /// Parse DNS record data of the given type from a byte string. fn parse_record_data_bytes( @@ -331,15 +334,6 @@ impl UnparsedRecordData { //--- Parsing record data impl<'a> ParseRecordData<'a> for &'a UnparsedRecordData { - fn parse_record_data( - message: &'a Message, - range: Range, - rtype: RType, - ) -> Result { - let bytes = message.as_bytes().get(range).ok_or(ParseError)?; - Self::parse_record_data_bytes(bytes, rtype) - } - fn parse_record_data_bytes( bytes: &'a [u8], _rtype: RType, From ba260f0e60f830671ced8313a9e4d92290ca0e68 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 23:53:00 +0100 Subject: [PATCH 046/167] [new_rdata/basic] Use more capitalization in record data type names --- src/new_rdata/basic.rs | 63 ++++++++++++++++++++++++++++++++++-------- src/new_rdata/mod.rs | 2 +- 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 2abb38cf4..cf6a99c39 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -164,14 +164,14 @@ impl BuildInto for Ns { /// The canonical name for this domain. #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] -pub struct Cname { +pub struct CName { /// The canonical name. pub name: N, } //--- Parsing from DNS messages -impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Cname { +impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for CName { fn parse_from_message( message: &'a Message, range: Range, @@ -182,7 +182,7 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Cname { //--- Building into DNS messages -impl BuildIntoMessage for Cname { +impl BuildIntoMessage for CName { fn build_into_message( &self, builder: build::Builder<'_>, @@ -193,7 +193,7 @@ impl BuildIntoMessage for Cname { //--- Parsing from bytes -impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Cname { +impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for CName { fn parse_from(bytes: &'a [u8]) -> Result { N::parse_from(bytes).map(|name| Self { name }) } @@ -201,7 +201,7 @@ impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Cname { //--- Building into bytes -impl BuildInto for Cname { +impl BuildInto for CName { fn build_into<'b>( &self, bytes: &'b mut [u8], @@ -442,11 +442,11 @@ impl BuildInto for Ptr { } } -//----------- Hinfo ---------------------------------------------------------- +//----------- HInfo ---------------------------------------------------------- /// Information about the host computer. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Hinfo<'a> { +pub struct HInfo<'a> { /// The CPU type. pub cpu: &'a CharStr, @@ -456,7 +456,7 @@ pub struct Hinfo<'a> { //--- Parsing from DNS messages -impl<'a> ParseFromMessage<'a> for Hinfo<'a> { +impl<'a> ParseFromMessage<'a> for HInfo<'a> { fn parse_from_message( message: &'a Message, range: Range, @@ -471,7 +471,7 @@ impl<'a> ParseFromMessage<'a> for Hinfo<'a> { //--- Building into DNS messages -impl BuildIntoMessage for Hinfo<'_> { +impl BuildIntoMessage for HInfo<'_> { fn build_into_message( &self, mut builder: build::Builder<'_>, @@ -485,7 +485,7 @@ impl BuildIntoMessage for Hinfo<'_> { //--- Parsing from bytes -impl<'a> ParseFrom<'a> for Hinfo<'a> { +impl<'a> ParseFrom<'a> for HInfo<'a> { fn parse_from(bytes: &'a [u8]) -> Result { let (cpu, rest) = <&CharStr>::split_from(bytes)?; let os = <&CharStr>::parse_from(rest)?; @@ -495,7 +495,7 @@ impl<'a> ParseFrom<'a> for Hinfo<'a> { //--- Building into bytes -impl BuildInto for Hinfo<'_> { +impl BuildInto for HInfo<'_> { fn build_into<'b>( &self, mut bytes: &'b mut [u8], @@ -586,7 +586,23 @@ pub struct Txt { content: [u8], } -// TODO: Support for iterating over the contained 'CharStr's. +//--- Interaction + +impl Txt { + /// Iterate over the [`CharStr`]s in this record. + pub fn iter<'a>( + &'a self, + ) -> impl Iterator> + 'a { + // NOTE: A TXT record always has at least one 'CharStr' within. + let first = <&CharStr>::split_from(&self.content); + core::iter::successors(Some(first), |prev| { + prev.as_ref() + .ok() + .map(|(_elem, rest)| <&CharStr>::split_from(rest)) + }) + .map(|result| result.map(|(elem, _rest)| elem)) + } +} //--- Parsing from DNS messages @@ -639,3 +655,26 @@ impl BuildInto for Txt { self.content.build_into(bytes) } } + +//--- Formatting + +impl fmt::Debug for Txt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Content<'a>(&'a Txt); + impl fmt::Debug for Content<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut list = f.debug_list(); + for elem in self.0.iter() { + if let Ok(elem) = elem { + list.entry(&elem); + } else { + list.entry(&ParseError); + } + } + list.finish() + } + } + + f.debug_tuple("Txt").field(&Content(self)).finish() + } +} diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 1aad4cca0..8fb32032f 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -1,7 +1,7 @@ //! Record data types. mod basic; -pub use basic::{Cname, Hinfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; +pub use basic::{CName, HInfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; mod ipv6; pub use ipv6::Aaaa; From 6d358a3e81ef08142511fed1a1f462fee5eeed98 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 30 Dec 2024 23:53:36 +0100 Subject: [PATCH 047/167] [new_rdata] Define enum 'RecordData' --- src/new_base/record.rs | 34 ++++++++ src/new_rdata/mod.rs | 176 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+) diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 93038f8d9..6354611ed 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -240,6 +240,40 @@ pub struct RType { pub code: U16, } +//--- Associated Constants + +impl RType { + /// The type of an [`A`](crate::new_rdata::A) record. + pub const A: Self = Self { code: U16::new(1) }; + + /// The type of an [`Ns`](crate::new_rdata::Ns) record. + pub const NS: Self = Self { code: U16::new(2) }; + + /// The type of a [`CName`](crate::new_rdata::CName) record. + pub const CNAME: Self = Self { code: U16::new(5) }; + + /// The type of an [`Soa`](crate::new_rdata::Soa) record. + pub const SOA: Self = Self { code: U16::new(6) }; + + /// The type of a [`Wks`](crate::new_rdata::Wks) record. + pub const WKS: Self = Self { code: U16::new(11) }; + + /// The type of a [`Ptr`](crate::new_rdata::Ptr) record. + pub const PTR: Self = Self { code: U16::new(12) }; + + /// The type of a [`HInfo`](crate::new_rdata::HInfo) record. + pub const HINFO: Self = Self { code: U16::new(13) }; + + /// The type of a [`Mx`](crate::new_rdata::Mx) record. + pub const MX: Self = Self { code: U16::new(15) }; + + /// The type of a [`Txt`](crate::new_rdata::Txt) record. + pub const TXT: Self = Self { code: U16::new(16) }; + + /// The type of an [`Aaaa`](crate::new_rdata::Aaaa) record. + pub const AAAA: Self = Self { code: U16::new(28) }; +} + //----------- RClass --------------------------------------------------------- /// The class of a record. diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 8fb32032f..67f4d9cb3 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -1,5 +1,19 @@ //! Record data types. +use core::ops::Range; + +use zerocopy_derive::*; + +use crate::new_base::{ + build::{BuildInto, BuildIntoMessage, Builder, TruncationError}, + parse::{ + ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + }, + Message, ParseRecordData, RType, +}; + +//----------- Concrete record data types ------------------------------------- + mod basic; pub use basic::{CName, HInfo, Mx, Ns, Ptr, Soa, Txt, Wks, A}; @@ -8,3 +22,165 @@ pub use ipv6::Aaaa; mod edns; pub use edns::Opt; + +//----------- RecordData ----------------------------------------------------- + +/// DNS record data. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub enum RecordData<'a, N> { + /// The IPv4 address of a host responsible for this domain. + A(&'a A), + + /// The authoritative name server for this domain. + Ns(Ns), + + /// The canonical name for this domain. + CName(CName), + + /// The start of a zone of authority. + Soa(Soa), + + /// Well-known services supported on this domain. + Wks(&'a Wks), + + /// A pointer to another domain name. + Ptr(Ptr), + + /// Information about the host computer. + HInfo(HInfo<'a>), + + /// A host that can exchange mail for this domain. + Mx(Mx), + + /// Free-form text strings about this domain. + Txt(&'a Txt), + + /// The IPv6 address of a host responsible for this domain. + Aaaa(&'a Aaaa), + + /// Data for an unknown DNS record type. + Unknown(RType, &'a UnknownRecordData), +} + +//--- Parsing record data + +impl<'a, N> ParseRecordData<'a> for RecordData<'a, N> +where + N: SplitFrom<'a> + SplitFromMessage<'a>, +{ + fn parse_record_data( + message: &'a Message, + range: Range, + rtype: RType, + ) -> Result { + match rtype { + RType::A => <&A>::parse_from_message(message, range).map(Self::A), + RType::NS => Ns::parse_from_message(message, range).map(Self::Ns), + RType::CNAME => { + CName::parse_from_message(message, range).map(Self::CName) + } + RType::SOA => { + Soa::parse_from_message(message, range).map(Self::Soa) + } + RType::WKS => { + <&Wks>::parse_from_message(message, range).map(Self::Wks) + } + RType::PTR => { + Ptr::parse_from_message(message, range).map(Self::Ptr) + } + RType::HINFO => { + HInfo::parse_from_message(message, range).map(Self::HInfo) + } + RType::MX => Mx::parse_from_message(message, range).map(Self::Mx), + RType::TXT => { + <&Txt>::parse_from_message(message, range).map(Self::Txt) + } + RType::AAAA => { + <&Aaaa>::parse_from_message(message, range).map(Self::Aaaa) + } + _ => <&UnknownRecordData>::parse_from_message(message, range) + .map(|data| Self::Unknown(rtype, data)), + } + } + + fn parse_record_data_bytes( + bytes: &'a [u8], + rtype: RType, + ) -> Result { + match rtype { + RType::A => <&A>::parse_from(bytes).map(Self::A), + RType::NS => Ns::parse_from(bytes).map(Self::Ns), + RType::CNAME => CName::parse_from(bytes).map(Self::CName), + RType::SOA => Soa::parse_from(bytes).map(Self::Soa), + RType::WKS => <&Wks>::parse_from(bytes).map(Self::Wks), + RType::PTR => Ptr::parse_from(bytes).map(Self::Ptr), + RType::HINFO => HInfo::parse_from(bytes).map(Self::HInfo), + RType::MX => Mx::parse_from(bytes).map(Self::Mx), + RType::TXT => <&Txt>::parse_from(bytes).map(Self::Txt), + RType::AAAA => <&Aaaa>::parse_from(bytes).map(Self::Aaaa), + _ => <&UnknownRecordData>::parse_from(bytes) + .map(|data| Self::Unknown(rtype, data)), + } + } +} + +//--- Building record data + +impl<'a, N> BuildIntoMessage for RecordData<'a, N> +where + N: BuildIntoMessage, +{ + fn build_into_message( + &self, + builder: Builder<'_>, + ) -> Result<(), TruncationError> { + match self { + Self::A(r) => r.build_into_message(builder), + Self::Ns(r) => r.build_into_message(builder), + Self::CName(r) => r.build_into_message(builder), + Self::Soa(r) => r.build_into_message(builder), + Self::Wks(r) => r.build_into_message(builder), + Self::Ptr(r) => r.build_into_message(builder), + Self::HInfo(r) => r.build_into_message(builder), + Self::Txt(r) => r.build_into_message(builder), + Self::Aaaa(r) => r.build_into_message(builder), + Self::Mx(r) => r.build_into_message(builder), + Self::Unknown(_, r) => r.octets.build_into_message(builder), + } + } +} + +impl<'a, N> BuildInto for RecordData<'a, N> +where + N: BuildInto, +{ + fn build_into<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + match self { + Self::A(r) => r.build_into(bytes), + Self::Ns(r) => r.build_into(bytes), + Self::CName(r) => r.build_into(bytes), + Self::Soa(r) => r.build_into(bytes), + Self::Wks(r) => r.build_into(bytes), + Self::Ptr(r) => r.build_into(bytes), + Self::HInfo(r) => r.build_into(bytes), + Self::Txt(r) => r.build_into(bytes), + Self::Aaaa(r) => r.build_into(bytes), + Self::Mx(r) => r.build_into(bytes), + Self::Unknown(_, r) => r.octets.build_into(bytes), + } + } +} + +//----------- UnknownRecordData ---------------------------------------------- + +/// Data for an unknown DNS record type. +#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[repr(C)] +pub struct UnknownRecordData { + /// The unparsed option data. + pub octets: [u8], +} From bd08a473ff7f1fd1983769a7cf1d3150ab6d638b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 31 Dec 2024 00:00:30 +0100 Subject: [PATCH 048/167] [new_rdata/basic] Elide lifetime as per Clippy --- src/new_rdata/basic.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index cf6a99c39..4807784b3 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -590,9 +590,9 @@ pub struct Txt { impl Txt { /// Iterate over the [`CharStr`]s in this record. - pub fn iter<'a>( - &'a self, - ) -> impl Iterator> + 'a { + pub fn iter( + &self, + ) -> impl Iterator> + '_ { // NOTE: A TXT record always has at least one 'CharStr' within. let first = <&CharStr>::split_from(&self.content); core::iter::successors(Some(first), |prev| { From 54f131effb1a84457d52d4bcc91c840f959fd19a Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 31 Dec 2024 11:56:00 +0100 Subject: [PATCH 049/167] [new_rdata] Elide lifetimes as per clippy --- src/new_rdata/mod.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 67f4d9cb3..3228608cc 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -127,10 +127,7 @@ where //--- Building record data -impl<'a, N> BuildIntoMessage for RecordData<'a, N> -where - N: BuildIntoMessage, -{ +impl BuildIntoMessage for RecordData<'_, N> { fn build_into_message( &self, builder: Builder<'_>, @@ -151,10 +148,7 @@ where } } -impl<'a, N> BuildInto for RecordData<'a, N> -where - N: BuildInto, -{ +impl BuildInto for RecordData<'_, N> { fn build_into<'b>( &self, bytes: &'b mut [u8], From 6881f6ad4b0de73acf651ce0ab281a673cc177c9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 1 Jan 2025 11:35:09 +0100 Subject: [PATCH 050/167] Set up a 'domain-macros' crate 'domain-macros' will provide: - 'derive' macros for DNS-specific (zero-copy) serialization - 'derive' macros for building and parsing specialized DNS messages - 'derive' macros for composing clients and servers The first use case will replace 'zerocopy'. --- Cargo.lock | 10 ++++++++++ Cargo.toml | 10 ++++++---- macros/Cargo.toml | 23 +++++++++++++++++++++++ macros/src/lib.rs | 3 +++ 4 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 macros/Cargo.toml create mode 100644 macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 7506702e2..d9833efa5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -245,6 +245,7 @@ dependencies = [ "arc-swap", "bytes", "chrono", + "domain-macros", "futures-util", "hashbrown 0.14.5", "heapless", @@ -282,6 +283,15 @@ dependencies = [ "zerocopy-derive 0.8.13", ] +[[package]] +name = "domain-macros" +version = "0.10.3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "equivalent" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index 9d078a5e5..041d83731 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,7 @@ +[workspace] +resolver = "2" +members = [".", "./macros"] + [package] name = "domain" version = "0.10.3" @@ -16,11 +20,9 @@ readme = "README.md" keywords = ["DNS", "domain"] license = "BSD-3-Clause" -[lib] -name = "domain" -path = "src/lib.rs" - [dependencies] +domain-macros = { path = "./macros", version = "0.10.3" } + arbitrary = { version = "1.4.1", optional = true, features = ["derive"] } octseq = { version = "0.5.2", default-features = false } time = { version = "0.3.1", default-features = false } diff --git a/macros/Cargo.toml b/macros/Cargo.toml new file mode 100644 index 000000000..b94e60bbd --- /dev/null +++ b/macros/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "domain-macros" + +# Copied from 'domain'. +version = "0.10.3" +rust-version = "1.68.2" +edition = "2021" + +authors = ["NLnet Labs "] +description = "Procedural macros for the `domain` crate." +documentation = "https://docs.rs/domain-macros" +homepage = "https://github.com/nlnetlabs/domain/" +repository = "https://github.com/nlnetlabs/domain/" +keywords = ["DNS", "domain"] +license = "BSD-3-Clause" + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +syn = "2.0" +quote = "1.0" diff --git a/macros/src/lib.rs b/macros/src/lib.rs new file mode 100644 index 000000000..0bf081d73 --- /dev/null +++ b/macros/src/lib.rs @@ -0,0 +1,3 @@ +//! Procedural macros for [`domain`]. +//! +//! [`domain`]: https://docs.rs/domain From 155fb5acabbde4a0258f9e0a2377e8e0fee76169 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 1 Jan 2025 12:32:37 +0100 Subject: [PATCH 051/167] [new_base/parse] Define 'ParseBytesByRef' for deriving --- src/new_base/parse/mod.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 31ad191ba..920395bd9 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -97,6 +97,33 @@ pub trait ParseFrom<'a>: Sized { fn parse_from(bytes: &'a [u8]) -> Result; } +/// Zero-copy parsing from a byte string. +/// +/// # Safety +/// +/// Every implementation of [`ParseBytesByRef`] must satisfy the invariants +/// documented on [`parse_bytes_by_ref()`]. An incorrect implementation is +/// considered to cause undefined behaviour. +/// +/// Implementing types should almost always be unaligned, but foregoing this +/// will not cause undefined behaviour (however, it will be very confusing for +/// users). +pub unsafe trait ParseBytesByRef { + /// Interpret a byte string as an instance of [`Self`]. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The whole byte string will be used. If the input is not a + /// valid instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let this: &T = T::parse_bytes_by_ref(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this)`. + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError>; +} + //--- Carrying over 'zerocopy' traits // NOTE: We can't carry over 'read_from_prefix' because the trait impls would From cf170e6af2f433fcd9d6ddfc8de5ab6212f811e4 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 1 Jan 2025 19:25:00 +0100 Subject: [PATCH 052/167] [macros] Define 'ImplSkeleton' and prepare a basic derive --- macros/Cargo.toml | 13 ++- macros/src/impls.rs | 237 ++++++++++++++++++++++++++++++++++++++++++++ macros/src/lib.rs | 52 ++++++++++ src/lib.rs | 8 ++ 4 files changed, 306 insertions(+), 4 deletions(-) create mode 100644 macros/src/impls.rs diff --git a/macros/Cargo.toml b/macros/Cargo.toml index b94e60bbd..263db27af 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -17,7 +17,12 @@ license = "BSD-3-Clause" [lib] proc-macro = true -[dependencies] -proc-macro2 = "1.0" -syn = "2.0" -quote = "1.0" +[dependencies.proc-macro2] +version = "1.0" + +[dependencies.syn] +version = "2.0" +features = ["visit"] + +[dependencies.quote] +version = "1.0" diff --git a/macros/src/impls.rs b/macros/src/impls.rs new file mode 100644 index 000000000..c5af737fb --- /dev/null +++ b/macros/src/impls.rs @@ -0,0 +1,237 @@ +//! Helpers for generating `impl` blocks. + +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{punctuated::Punctuated, visit::Visit, *}; + +//----------- ImplSkeleton --------------------------------------------------- + +/// The skeleton of an `impl` block. +pub struct ImplSkeleton { + /// Lifetime parameters for the `impl` block. + pub lifetimes: Vec, + + /// Type parameters for the `impl` block. + pub types: Vec, + + /// Const generic parameters for the `impl` block. + pub consts: Vec, + + /// Whether the `impl` is unsafe. + pub unsafety: Option, + + /// The trait being implemented. + pub bound: Path, + + /// The type being implemented on. + pub subject: Path, + + /// The where clause of the `impl` block. + pub where_clause: WhereClause, + + /// The contents of the `impl`. + pub contents: Block, + + /// A `const` block for asserting requirements. + pub requirements: Block, +} + +impl ImplSkeleton { + /// Construct an [`ImplSkeleton`] for a [`DeriveInput`]. + pub fn new(input: &DeriveInput, unsafety: bool, bound: Path) -> Self { + let mut lifetimes = Vec::new(); + let mut types = Vec::new(); + let mut consts = Vec::new(); + let mut subject_args = Punctuated::new(); + + for param in &input.generics.params { + match param { + GenericParam::Lifetime(value) => { + lifetimes.push(value.clone()); + let id = value.lifetime.clone(); + subject_args.push(GenericArgument::Lifetime(id)); + } + + GenericParam::Type(value) => { + types.push(value.clone()); + let id = value.ident.clone(); + let id = TypePath { + qself: None, + path: Path { + leading_colon: None, + segments: [PathSegment { + ident: id, + arguments: PathArguments::None, + }] + .into_iter() + .collect(), + }, + }; + subject_args.push(GenericArgument::Type(id.into())); + } + + GenericParam::Const(value) => { + consts.push(value.clone()); + let id = value.ident.clone(); + let id = TypePath { + qself: None, + path: Path { + leading_colon: None, + segments: [PathSegment { + ident: id, + arguments: PathArguments::None, + }] + .into_iter() + .collect(), + }, + }; + subject_args.push(GenericArgument::Type(id.into())); + } + } + } + + let unsafety = unsafety.then_some(::default()); + + let subject = Path { + leading_colon: None, + segments: [PathSegment { + ident: input.ident.clone(), + arguments: PathArguments::AngleBracketed( + AngleBracketedGenericArguments { + colon2_token: None, + lt_token: Default::default(), + args: subject_args, + gt_token: Default::default(), + }, + ), + }] + .into_iter() + .collect(), + }; + + let where_clause = + input.generics.where_clause.clone().unwrap_or(WhereClause { + where_token: Default::default(), + predicates: Punctuated::new(), + }); + + let contents = Block { + brace_token: Default::default(), + stmts: Vec::new(), + }; + + let requirements = Block { + brace_token: Default::default(), + stmts: Vec::new(), + }; + + Self { + lifetimes, + types, + consts, + unsafety, + bound, + subject, + where_clause, + contents, + requirements, + } + } + + /// Require a bound for a type. + /// + /// If the type is concrete, a verifying statement is added for it. + /// Otherwise, it is added to the where clause. + pub fn require_bound(&mut self, target: Type, bound: TypeParamBound) { + if self.is_concrete(&target) { + // Add a concrete requirement for this bound. + self.requirements.stmts.push(parse_quote! { + const _: fn() = || { + fn assert_impl() {} + assert_impl::<#target>(); + }; + }); + } else { + // Add this bound to the `where` clause. + let mut bounds = Punctuated::new(); + bounds.push_value(bound); + let pred = WherePredicate::Type(PredicateType { + lifetimes: None, + bounded_ty: target, + colon_token: Default::default(), + bounds, + }); + self.where_clause.predicates.push_value(pred); + } + } + + /// Whether a type is concrete within this `impl` block. + pub fn is_concrete(&self, target: &Type) -> bool { + struct ConcretenessVisitor<'a> { + /// The `impl` skeleton being added to. + skeleton: &'a ImplSkeleton, + + /// Whether the visited type is concrete. + is_concrete: bool, + } + + impl<'ast> Visit<'ast> for ConcretenessVisitor<'_> { + fn visit_lifetime(&mut self, i: &'ast Lifetime) { + self.is_concrete = self.is_concrete + && self + .skeleton + .lifetimes + .iter() + .all(|l| l.lifetime != *i); + } + + fn visit_ident(&mut self, i: &'ast Ident) { + self.is_concrete = self.is_concrete + && self.skeleton.types.iter().all(|t| t.ident != *i); + self.is_concrete = self.is_concrete + && self.skeleton.consts.iter().all(|c| c.ident != *i); + } + } + + let mut visitor = ConcretenessVisitor { + skeleton: self, + is_concrete: true, + }; + + visitor.visit_type(target); + + visitor.is_concrete + } +} + +impl ToTokens for ImplSkeleton { + fn to_tokens(&self, tokens: &mut TokenStream) { + let Self { + lifetimes, + types, + consts, + unsafety, + bound, + subject, + where_clause, + contents, + requirements, + } = self; + + quote! { + #unsafety + impl<#(#lifetimes,)* #(#types,)* #(#consts,)*> + #bound for #subject + #where_clause + #contents + } + .to_tokens(tokens); + + if !requirements.stmts.is_empty() { + quote! { + const _: () = #requirements; + } + .to_tokens(tokens); + } + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 0bf081d73..68f956adc 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,3 +1,55 @@ //! Procedural macros for [`domain`]. //! //! [`domain`]: https://docs.rs/domain + +use proc_macro as pm; +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::*; + +mod impls; +use impls::ImplSkeleton; + +//----------- ParseBytesByRef ------------------------------------------------ + +#[proc_macro_derive(ParseBytesByRef)] +pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let bound = parse_quote!(::domain::new_base::parse::ParseBytesByRef); + let mut skeleton = ImplSkeleton::new(&input, true, bound); + + let data = match input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'ParseBytesByRef' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'ParseBytesByRef' can only be 'derive'd for 'struct's", + )); + } + }; + + // TODO: Ensure that the type is 'repr(C)' or 'repr(transparent)'. + + // Every field must implement 'ParseBytesByRef'. + for field in data.fields.iter() { + let bound = + parse_quote!(::domain::new_base::parse::ParseBytesByRef); + skeleton.require_bound(field.ty.clone(), bound); + } + + // TODO: Implement 'parse_bytes_by_ref()' in 'skeleton.contents'. + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/src/lib.rs b/src/lib.rs index 4d08972f0..6fe1aeec9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -190,6 +190,14 @@ extern crate std; #[macro_use] extern crate core; +// The 'domain-macros' crate introduces 'derive' macros which can be used by +// users of the 'domain' crate, but also by the 'domain' crate itself. Within +// those macros, references to declarations in the 'domain' crate are written +// as '::domain::*' ... but this doesn't work when those proc macros are used +// by the 'domain' crate itself. The alias introduced here fixes this: now +// '::domain' means the same thing within this crate as in dependents of it. +extern crate self as domain; + pub mod base; pub mod dep; pub mod net; From 89ee79785dd273a356c8f180592bc7308f0269d8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 12:58:04 +0100 Subject: [PATCH 053/167] Expand 'ParseBytesByRef' and largely finish its derive macro --- macros/src/lib.rs | 104 +++++++++++++++++++++++--- src/lib.rs | 9 ++- src/new_base/parse/mod.rs | 153 +++++++++++++++++++++++++++++++++++++- 3 files changed, 249 insertions(+), 17 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 68f956adc..255567046 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -4,7 +4,8 @@ use proc_macro as pm; use proc_macro2::TokenStream; -use quote::ToTokens; +use quote::{quote, ToTokens}; +use spanned::Spanned; use syn::*; mod impls; @@ -15,10 +16,7 @@ use impls::ImplSkeleton; #[proc_macro_derive(ParseBytesByRef)] pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { fn inner(input: DeriveInput) -> Result { - let bound = parse_quote!(::domain::new_base::parse::ParseBytesByRef); - let mut skeleton = ImplSkeleton::new(&input, true, bound); - - let data = match input.data { + let data = match &input.data { Data::Struct(data) => data, Data::Enum(data) => { return Err(Error::new_spanned( @@ -36,14 +34,98 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // TODO: Ensure that the type is 'repr(C)' or 'repr(transparent)'. - // Every field must implement 'ParseBytesByRef'. - for field in data.fields.iter() { - let bound = - parse_quote!(::domain::new_base::parse::ParseBytesByRef); - skeleton.require_bound(field.ty.clone(), bound); + // Split up the last field from the rest. + let mut fields = data.fields.iter(); + let Some(last) = fields.next_back() else { + // This type has no fields. Return a simple implementation. + let (impl_generics, ty_generics, where_clause) = + input.generics.split_for_impl(); + let name = input.ident; + + return Ok(quote! { + impl #impl_generics + ::domain::new_base::parse::ParseBytesByRef + for #name #ty_generics + #where_clause { + fn parse_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + &Self, + ::domain::new_base::parse::ParseError, + > { + Ok(unsafe { &*bytes.as_ptr().cast::() }) + } + + fn ptr_with_address( + &self, + addr: *const (), + ) -> *const Self { + addr.cast() + } + } + }); + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let bound = parse_quote!(::domain::new_base::parse::ParseBytesByRef); + let mut skeleton = ImplSkeleton::new(&input, true, bound); + + // Establish bounds on the fields. + for field in fields.clone() { + // This field should implement 'SplitBytesByRef'. + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::SplitBytesByRef), + ); } + // The last field should implement 'ParseBytesByRef'. + skeleton.require_bound( + last.ty.clone(), + parse_quote!(::domain::new_base::parse::ParseBytesByRef), + ); - // TODO: Implement 'parse_bytes_by_ref()' in 'skeleton.contents'. + // Define 'parse_bytes_by_ref()'. + let tys = fields.clone().map(|f| &f.ty); + let last_ty = &last.ty; + skeleton.contents.stmts.push(parse_quote! { + fn parse_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + &Self, + ::domain::new_base::parse::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::parse::SplitBytesByRef> + ::split_bytes_by_ref(bytes)?;)* + let last = + <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + ::parse_bytes_by_ref(bytes)?; + let ptr = + <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - By + Ok(unsafe { &*(ptr as *const Self) }) + } + }); + + // Define 'ptr_with_address()'. + let last_name = match last.ident.as_ref() { + Some(ident) => Member::Named(ident.clone()), + None => Member::Unnamed(Index { + index: data.fields.len() as u32 - 1, + span: last.ty.span(), + }), + }; + skeleton.contents.stmts.push(parse_quote! { + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + ::ptr_with_address(&self.#last_name, addr) + as *const Self + } + }); Ok(skeleton.into_token_stream().into()) } diff --git a/src/lib.rs b/src/lib.rs index 6fe1aeec9..40b4efd7a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -187,17 +187,18 @@ #[macro_use] extern crate std; -#[macro_use] -extern crate core; - // The 'domain-macros' crate introduces 'derive' macros which can be used by // users of the 'domain' crate, but also by the 'domain' crate itself. Within // those macros, references to declarations in the 'domain' crate are written // as '::domain::*' ... but this doesn't work when those proc macros are used -// by the 'domain' crate itself. The alias introduced here fixes this: now +// in the 'domain' crate itself. The alias introduced here fixes this: now // '::domain' means the same thing within this crate as in dependents of it. extern crate self as domain; +// Re-export 'core' for use in macros. +#[doc(hidden)] +pub use core as __core; + pub mod base; pub mod dep; pub mod net; diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 920395bd9..b218646c9 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -97,13 +97,47 @@ pub trait ParseFrom<'a>: Sized { fn parse_from(bytes: &'a [u8]) -> Result; } +/// Zero-copy parsing from the start of a byte string. +/// +/// # Safety +/// +/// Every implementation of [`SplitBytesByRef`] must satisfy the invariants +/// documented on [`split_bytes_by_ref()`]. An incorrect implementation is +/// considered to cause undefined behaviour. +/// +/// Implementing types should almost always be unaligned, but foregoing this +/// will not cause undefined behaviour (however, it will be very confusing for +/// users). +pub unsafe trait SplitBytesByRef: ParseBytesByRef { + /// Interpret a byte string as an instance of [`Self`]. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The length of [`Self`] will be determined, possibly based + /// on the contents (but not the length!) of the input, and the remaining + /// bytes will be returned. If the input does not begin with a valid + /// instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let (this, rest) = T::split_bytes_by_ref(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this) + rest.len()`. + /// - `bytes.as_ptr().offset(size_of_val(this)) == rest.as_ptr()`. + fn split_bytes_by_ref(bytes: &[u8]) + -> Result<(&Self, &[u8]), ParseError>; +} + /// Zero-copy parsing from a byte string. /// /// # Safety /// /// Every implementation of [`ParseBytesByRef`] must satisfy the invariants -/// documented on [`parse_bytes_by_ref()`]. An incorrect implementation is -/// considered to cause undefined behaviour. +/// documented on [`parse_bytes_by_ref()`] and [`ptr_with_address()`]. An +/// incorrect implementation is considered to cause undefined behaviour. +/// +/// [`parse_bytes_by_ref()`]: Self::parse_bytes_by_ref() +/// [`ptr_with_address()`]: Self::ptr_with_address() /// /// Implementing types should almost always be unaligned, but foregoing this /// will not cause undefined behaviour (however, it will be very confusing for @@ -122,6 +156,121 @@ pub unsafe trait ParseBytesByRef { /// - `bytes.as_ptr() == this as *const T as *const u8`. /// - `bytes.len() == core::mem::size_of_val(this)`. fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError>; + + /// Change the address of a pointer to [`Self`]. + /// + /// When [`Self`] is used as the last field in a type that also implements + /// [`ParseBytesByRef`], it may be dynamically sized, and so a pointer (or + /// reference) to it may include additional metadata. This metadata is + /// included verbatim in any reference/pointer to the containing type. + /// + /// When the containing type implements [`ParseBytesByRef`], it needs to + /// construct a reference/pointer to itself, which includes this metadata. + /// Rust does not currently offer a general way to extract this metadata + /// or pair it with another address, so this function is necessary. The + /// caller can construct a reference to [`Self`], then change its address + /// to point to the containing type, then cast that pointer to the right + /// type. + /// + /// # Implementing + /// + /// Most users will derive [`ParseBytesByRef`] and so don't need to worry + /// about this. For manual implementations: + /// + /// In the future, an adequate default implementation for this function + /// may be provided. Until then, it should be implemented using one of + /// the following expressions: + /// + /// ```ignore + /// fn ptr_with_address( + /// &self, + /// addr: *const (), + /// ) -> *const Self { + /// // If 'Self' is Sized: + /// addr.cast::() + /// + /// // If 'Self' is an aggregate whose last field is 'last': + /// self.last.ptr_with_address(addr) as *const Self + /// } + /// ``` + /// + /// # Invariants + /// + /// For the statement `let result = Self::ptr_with_address(ptr, addr);`: + /// + /// - `result as usize == addr as usize`. + /// - `core::ptr::metadata(result) == core::ptr::metadata(ptr)`. + fn ptr_with_address(&self, addr: *const ()) -> *const Self; +} + +unsafe impl SplitBytesByRef for u8 { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + bytes.split_first().ok_or(ParseError) + } +} + +unsafe impl ParseBytesByRef for u8 { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + let [result] = bytes else { + return Err(ParseError); + }; + + return Ok(result); + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } +} + +unsafe impl ParseBytesByRef for [u8] { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + Ok(bytes) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + core::ptr::slice_from_raw_parts(addr.cast(), self.len()) + } +} + +unsafe impl SplitBytesByRef for [u8; N] { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + if bytes.len() < N { + Err(ParseError) + } else { + let (bytes, rest) = bytes.split_at(N); + + // SAFETY: + // - It is known that 'bytes.len() == N'. + // - Thus '&bytes' has the same layout as '[u8; N]'. + // - Thus it is safe to cast a pointer to it to '[u8; N]'. + // - The referenced data has the same lifetime as the output. + Ok((unsafe { &*bytes.as_ptr().cast::<[u8; N]>() }, rest)) + } + } +} + +unsafe impl ParseBytesByRef for [u8; N] { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + if bytes.len() != N { + Err(ParseError) + } else { + // SAFETY: + // - It is known that 'bytes.len() == N'. + // - Thus '&bytes' has the same layout as '[u8; N]'. + // - Thus it is safe to cast a pointer to it to '[u8; N]'. + // - The referenced data has the same lifetime as the output. + Ok(unsafe { &*bytes.as_ptr().cast::<[u8; N]>() }) + } + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } } //--- Carrying over 'zerocopy' traits From baaa8d2ad63ec90b634489aceff18a39615f7b74 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 13:13:11 +0100 Subject: [PATCH 054/167] [macros] Add module 'repr' for checking for stable layouts --- macros/src/lib.rs | 5 +++- macros/src/repr.rs | 68 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 macros/src/repr.rs diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 255567046..755467061 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -11,6 +11,9 @@ use syn::*; mod impls; use impls::ImplSkeleton; +mod repr; +use repr::Repr; + //----------- ParseBytesByRef ------------------------------------------------ #[proc_macro_derive(ParseBytesByRef)] @@ -32,7 +35,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }; - // TODO: Ensure that the type is 'repr(C)' or 'repr(transparent)'. + let _ = Repr::determine(&input.attrs)?; // Split up the last field from the rest. let mut fields = data.fields.iter(); diff --git a/macros/src/repr.rs b/macros/src/repr.rs new file mode 100644 index 000000000..428ef2a10 --- /dev/null +++ b/macros/src/repr.rs @@ -0,0 +1,68 @@ +//! Determining the memory layout of a type. + +use proc_macro2::Span; +use syn::{punctuated::Punctuated, *}; + +//----------- Repr ----------------------------------------------------------- + +/// The memory representation of a type. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum Repr { + /// Transparent to an underlying field. + Transparent, + + /// Compatible with C. + C, +} + +impl Repr { + /// Determine the representation for a type from its attributes. + /// + /// This will fail if a stable representation cannot be found. + pub fn determine(attrs: &[Attribute]) -> Result { + let mut repr = None; + for attr in attrs { + if !attr.path().is_ident("repr") { + continue; + } + + let nested = attr.parse_args_with( + Punctuated::::parse_terminated, + )?; + + // We don't check for consistency in the 'repr' attributes, since + // the compiler should be doing that for us anyway. This lets us + // ignore conflicting 'repr's entirely. + for meta in nested { + match meta { + Meta::Path(p) if p.is_ident("transparent") => { + repr = Some(Repr::Transparent); + } + + Meta::Path(p) if p.is_ident("C") => { + repr = Some(Repr::C); + } + + Meta::Path(p) if p.is_ident("Rust") => { + return Err(Error::new_spanned(p, + "repr(Rust) is not stable, cannot derive this for it")); + } + + meta => { + // We still need to error out here, in case a future + // version of Rust introduces more memory layout data + return Err(Error::new_spanned( + meta, + "unrecognized repr attribute", + )); + } + } + } + } + + repr.ok_or_else(|| { + Error::new(Span::call_site(), + "repr(C) or repr(transparent) must be specified to derive this") + }) + } +} From d4fc42f21e1c73e22fc01371456aa50a38567046 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 13:20:42 +0100 Subject: [PATCH 055/167] [new_base/parse] Implement '*BytesByRef' for '[T; N]' --- src/new_base/parse/mod.rs | 41 ++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index b218646c9..4c8917308 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -235,36 +235,33 @@ unsafe impl ParseBytesByRef for [u8] { } } -unsafe impl SplitBytesByRef for [u8; N] { +unsafe impl SplitBytesByRef for [T; N] { fn split_bytes_by_ref( - bytes: &[u8], + mut bytes: &[u8], ) -> Result<(&Self, &[u8]), ParseError> { - if bytes.len() < N { - Err(ParseError) - } else { - let (bytes, rest) = bytes.split_at(N); - - // SAFETY: - // - It is known that 'bytes.len() == N'. - // - Thus '&bytes' has the same layout as '[u8; N]'. - // - Thus it is safe to cast a pointer to it to '[u8; N]'. - // - The referenced data has the same lifetime as the output. - Ok((unsafe { &*bytes.as_ptr().cast::<[u8; N]>() }, rest)) + let start = bytes.as_ptr(); + for _ in 0..N { + (_, bytes) = T::split_bytes_by_ref(bytes)?; } + + // SAFETY: + // - 'T::split_bytes_by_ref()' was called 'N' times on successive + // positions, thus the original 'bytes' starts with 'N' instances + // of 'T' (even if 'T' is a ZST and so all instances overlap). + // - 'N' consecutive 'T's have the same layout as '[T; N]'. + // - Thus it is safe to cast 'start' to '[T; N]'. + // - The referenced data has the same lifetime as the output. + Ok((unsafe { &*start.cast::<[T; N]>() }, bytes)) } } -unsafe impl ParseBytesByRef for [u8; N] { +unsafe impl ParseBytesByRef for [T; N] { fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - if bytes.len() != N { - Err(ParseError) + let (this, rest) = Self::split_bytes_by_ref(bytes)?; + if rest.is_empty() { + Ok(this) } else { - // SAFETY: - // - It is known that 'bytes.len() == N'. - // - Thus '&bytes' has the same layout as '[u8; N]'. - // - Thus it is safe to cast a pointer to it to '[u8; N]'. - // - The referenced data has the same lifetime as the output. - Ok(unsafe { &*bytes.as_ptr().cast::<[u8; N]>() }) + Err(ParseError) } } From a1bfc4f1042edb4f43f3f3ae054e70fa323cf630 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 13:49:00 +0100 Subject: [PATCH 056/167] [macros] Add a derive macro for 'SplitBytesByRef' 'ParseBytesByRef' now requires all implementing types to be unaligned. Otherwise, padding bytes wouldn't be accounted for properly, e.g. in // Has alignment of largest field: 8 bytes #[repr(C)] pub struct Foo { a: u8, // 7 bytes of padding here b: u64, } The 'derive' can't tell how much padding to use, so it would parse a '[u8; 9]' as a valid instance of 'Foo'. Every 'repr(C)' would have to use 'repr(packed)' too. --- macros/src/lib.rs | 121 +++++++++++++++++++++++++++++++++++++- macros/src/repr.rs | 23 +++++++- src/new_base/parse/mod.rs | 6 +- 3 files changed, 141 insertions(+), 9 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 755467061..8cb26183f 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -14,6 +14,113 @@ use impls::ImplSkeleton; mod repr; use repr::Repr; +//----------- SplitBytesByRef ------------------------------------------------ + +#[proc_macro_derive(SplitBytesByRef)] +pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let data = match &input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'SplitBytesByRef' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'SplitBytesByRef' can only be 'derive'd for 'struct's", + )); + } + }; + + let _ = Repr::determine(&input.attrs, "SplitBytesByRef")?; + + // Split up the last field from the rest. + let mut fields = data.fields.iter(); + let Some(last) = fields.next_back() else { + // This type has no fields. Return a simple implementation. + let (impl_generics, ty_generics, where_clause) = + input.generics.split_for_impl(); + let name = input.ident; + + return Ok(quote! { + unsafe impl #impl_generics + ::domain::new_base::parse::SplitBytesByRef + for #name #ty_generics + #where_clause { + fn split_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (&Self, &[::domain::__core::primitive::u8]), + ::domain::new_base::parse::ParseError, + > { + Ok(( + unsafe { &*bytes.as_ptr().cast::() }, + bytes, + )) + } + } + }); + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let bound = parse_quote!(::domain::new_base::parse::SplitBytesByRef); + let mut skeleton = ImplSkeleton::new(&input, true, bound); + + // Establish bounds on the fields. + for field in data.fields.iter() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::SplitBytesByRef), + ); + } + + // Define 'split_bytes_by_ref()'. + let tys = fields.clone().map(|f| &f.ty); + let last_ty = &last.ty; + skeleton.contents.stmts.push(parse_quote! { + fn split_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (&Self, &[::domain::__core::primitive::u8]), + ::domain::new_base::parse::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::parse::SplitBytesByRef> + ::split_bytes_by_ref(bytes)?;)* + let (last, rest) = + <#last_ty as ::domain::new_base::parse::SplitBytesByRef> + ::split_bytes_by_ref(bytes)?; + let ptr = + <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. + Ok((unsafe { &*(ptr as *const Self) }, rest)) + } + }); + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + //----------- ParseBytesByRef ------------------------------------------------ #[proc_macro_derive(ParseBytesByRef)] @@ -35,7 +142,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }; - let _ = Repr::determine(&input.attrs)?; + let _ = Repr::determine(&input.attrs, "ParseBytesByRef")?; // Split up the last field from the rest. let mut fields = data.fields.iter(); @@ -46,7 +153,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { let name = input.ident; return Ok(quote! { - impl #impl_generics + unsafe impl #impl_generics ::domain::new_base::parse::ParseBytesByRef for #name #ty_generics #where_clause { @@ -109,7 +216,15 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { ::ptr_with_address(last, start as *const ()); // SAFETY: - // - By + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. Ok(unsafe { &*(ptr as *const Self) }) } }); diff --git a/macros/src/repr.rs b/macros/src/repr.rs index 428ef2a10..80c900eb6 100644 --- a/macros/src/repr.rs +++ b/macros/src/repr.rs @@ -1,7 +1,7 @@ //! Determining the memory layout of a type. use proc_macro2::Span; -use syn::{punctuated::Punctuated, *}; +use syn::{punctuated::Punctuated, spanned::Spanned, *}; //----------- Repr ----------------------------------------------------------- @@ -19,7 +19,7 @@ impl Repr { /// Determine the representation for a type from its attributes. /// /// This will fail if a stable representation cannot be found. - pub fn determine(attrs: &[Attribute]) -> Result { + pub fn determine(attrs: &[Attribute], bound: &str) -> Result { let mut repr = None; for attr in attrs { if !attr.path().is_ident("repr") { @@ -45,7 +45,24 @@ impl Repr { Meta::Path(p) if p.is_ident("Rust") => { return Err(Error::new_spanned(p, - "repr(Rust) is not stable, cannot derive this for it")); + format!("repr(Rust) is not stable, cannot derive {bound} for it"))); + } + + Meta::Path(p) if p.is_ident("packed") => { + // The alignment can be set to 1 safely. + } + + Meta::List(meta) + if meta.path.is_ident("packed") + || meta.path.is_ident("aligned") => + { + let span = meta.span(); + let lit: LitInt = parse2(meta.tokens)?; + let n: usize = lit.base10_parse()?; + if n != 1 { + return Err(Error::new(span, + format!("'Self' must be unaligned to derive {bound}"))); + } } meta => { diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 4c8917308..eb4815d04 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -139,9 +139,9 @@ pub unsafe trait SplitBytesByRef: ParseBytesByRef { /// [`parse_bytes_by_ref()`]: Self::parse_bytes_by_ref() /// [`ptr_with_address()`]: Self::ptr_with_address() /// -/// Implementing types should almost always be unaligned, but foregoing this -/// will not cause undefined behaviour (however, it will be very confusing for -/// users). +/// Implementing types must also have no alignment (i.e. a valid instance of +/// [`Self`] can occur at any address). This eliminates the possibility of +/// padding bytes, even when [`Self`] is part of a larger aggregate type. pub unsafe trait ParseBytesByRef { /// Interpret a byte string as an instance of [`Self`]. /// From 2164b8757b5aef89e5cbc6980f0f0c6449cf31b8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 14:08:43 +0100 Subject: [PATCH 057/167] Use '{Parse,Split}BytesByRef' instead of 'zerocopy' --- src/new_base/parse/mod.rs | 95 +++++++++++++++++++++++---------------- src/new_base/question.rs | 11 +++-- src/new_base/record.rs | 28 ++++++------ src/new_base/serial.rs | 6 +-- src/new_edns/mod.rs | 13 +++--- src/new_rdata/basic.rs | 8 ++-- src/new_rdata/ipv6.rs | 6 +-- src/new_rdata/mod.rs | 3 +- 8 files changed, 92 insertions(+), 78 deletions(-) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index eb4815d04..c5faf0440 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -2,7 +2,10 @@ use core::{fmt, ops::Range}; -use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; +use zerocopy::{ + network_endian::{U16, U32}, + FromBytes, IntoBytes, +}; mod message; pub use message::{MessagePart, ParseMessage, VisitMessagePart}; @@ -43,37 +46,26 @@ pub trait ParseFromMessage<'a>: Sized { ) -> Result; } -//--- Carrying over 'zerocopy' traits - -// NOTE: We can't carry over 'read_from_prefix' because the trait impls would -// conflict. We kept 'ref_from_prefix' since it's more general. - -impl<'a, T: ?Sized> SplitFromMessage<'a> for &'a T -where - T: FromBytes + KnownLayout + Immutable, -{ +impl<'a, T: ?Sized + SplitBytesByRef> SplitFromMessage<'a> for &'a T { fn split_from_message( message: &'a Message, start: usize, ) -> Result<(Self, usize), ParseError> { let message = message.as_bytes(); let bytes = message.get(start..).ok_or(ParseError)?; - let (this, rest) = T::ref_from_prefix(bytes)?; + let (this, rest) = T::split_bytes_by_ref(bytes)?; Ok((this, message.len() - rest.len())) } } -impl<'a, T: ?Sized> ParseFromMessage<'a> for &'a T -where - T: FromBytes + KnownLayout + Immutable, -{ +impl<'a, T: ?Sized + ParseBytesByRef> ParseFromMessage<'a> for &'a T { fn parse_from_message( message: &'a Message, range: Range, ) -> Result { let message = message.as_bytes(); let bytes = message.get(range).ok_or(ParseError)?; - Ok(T::ref_from_bytes(bytes)?) + T::parse_bytes_by_ref(bytes) } } @@ -97,6 +89,18 @@ pub trait ParseFrom<'a>: Sized { fn parse_from(bytes: &'a [u8]) -> Result; } +impl<'a, T: ?Sized + SplitBytesByRef> SplitFrom<'a> for &'a T { + fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + T::split_bytes_by_ref(bytes).map_err(|_| ParseError) + } +} + +impl<'a, T: ?Sized + ParseBytesByRef> ParseFrom<'a> for &'a T { + fn parse_from(bytes: &'a [u8]) -> Result { + T::parse_bytes_by_ref(bytes).map_err(|_| ParseError) + } +} + /// Zero-copy parsing from the start of a byte string. /// /// # Safety @@ -225,6 +229,42 @@ unsafe impl ParseBytesByRef for u8 { } } +unsafe impl SplitBytesByRef for U16 { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + Self::ref_from_prefix(bytes).map_err(Into::into) + } +} + +unsafe impl ParseBytesByRef for U16 { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + Self::ref_from_bytes(bytes).map_err(Into::into) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } +} + +unsafe impl SplitBytesByRef for U32 { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + Self::ref_from_prefix(bytes).map_err(Into::into) + } +} + +unsafe impl ParseBytesByRef for U32 { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + Self::ref_from_bytes(bytes).map_err(Into::into) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } +} + unsafe impl ParseBytesByRef for [u8] { fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { Ok(bytes) @@ -270,29 +310,6 @@ unsafe impl ParseBytesByRef for [T; N] { } } -//--- Carrying over 'zerocopy' traits - -// NOTE: We can't carry over 'read_from_prefix' because the trait impls would -// conflict. We kept 'ref_from_prefix' since it's more general. - -impl<'a, T: ?Sized> SplitFrom<'a> for &'a T -where - T: FromBytes + KnownLayout + Immutable, -{ - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - T::ref_from_prefix(bytes).map_err(|_| ParseError) - } -} - -impl<'a, T: ?Sized> ParseFrom<'a> for &'a T -where - T: FromBytes + KnownLayout + Immutable, -{ - fn parse_from(bytes: &'a [u8]) -> Result { - T::ref_from_bytes(bytes).map_err(|_| ParseError) - } -} - //----------- ParseError ----------------------------------------------------- /// A DNS message parsing error. diff --git a/src/new_base/question.rs b/src/new_base/question.rs index f173a664f..81411aaf0 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -2,6 +2,7 @@ use core::ops::Range; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{network_endian::U16, IntoBytes}; use zerocopy_derive::*; @@ -150,11 +151,10 @@ where PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct QType { @@ -174,11 +174,10 @@ pub struct QType { PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct QClass { diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 6354611ed..36d4c58dd 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -5,6 +5,7 @@ use core::{ ops::{Deref, Range}, }; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{ network_endian::{U16, U32}, FromBytes, IntoBytes, SizeError, @@ -156,9 +157,9 @@ where { fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (rname, rest) = N::split_from(bytes)?; - let (rtype, rest) = RType::read_from_prefix(rest)?; - let (rclass, rest) = RClass::read_from_prefix(rest)?; - let (ttl, rest) = TTL::read_from_prefix(rest)?; + let (&rtype, rest) = <&RType>::split_from(rest)?; + let (&rclass, rest) = <&RClass>::split_from(rest)?; + let (&ttl, rest) = <&TTL>::split_from(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; @@ -175,9 +176,9 @@ where { fn parse_from(bytes: &'a [u8]) -> Result { let (rname, rest) = N::split_from(bytes)?; - let (rtype, rest) = RType::read_from_prefix(rest)?; - let (rclass, rest) = RClass::read_from_prefix(rest)?; - let (ttl, rest) = TTL::read_from_prefix(rest)?; + let (&rtype, rest) = <&RType>::split_from(rest)?; + let (&rclass, rest) = <&RClass>::split_from(rest)?; + let (&ttl, rest) = <&TTL>::split_from(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; @@ -228,11 +229,10 @@ where PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct RType { @@ -286,11 +286,10 @@ impl RType { PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct RClass { @@ -310,11 +309,10 @@ pub struct RClass { PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct TTL { diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs index fe00923c3..f351e1a46 100644 --- a/src/new_base/serial.rs +++ b/src/new_base/serial.rs @@ -8,6 +8,7 @@ use core::{ ops::{Add, AddAssign}, }; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::network_endian::U32; use zerocopy_derive::*; @@ -21,11 +22,10 @@ use zerocopy_derive::*; PartialEq, Eq, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct Serial(U32); diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index f4d12f9c1..f15132d07 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -4,6 +4,7 @@ use core::{fmt, ops::Range}; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{network_endian::U16, FromBytes, IntoBytes}; use zerocopy_derive::*; @@ -128,11 +129,10 @@ impl<'a> ParseFrom<'a> for EdnsRecord<'a> { Clone, Default, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct EdnsFlags { @@ -207,11 +207,10 @@ pub enum EdnsOption<'b> { PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct OptionCode { @@ -222,7 +221,7 @@ pub struct OptionCode { //----------- UnknownOption -------------------------------------------------- /// Data for an unknown Extended DNS option. -#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[derive(Debug, IntoBytes, Immutable, ParseBytesByRef)] #[repr(C)] pub struct UnknownOption { /// The unparsed option data. diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 4807784b3..1b5c0baeb 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -10,6 +10,7 @@ use core::str::FromStr; #[cfg(feature = "std")] use std::net::Ipv4Addr; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{ network_endian::{U16, U32}, IntoBytes, @@ -36,11 +37,10 @@ use crate::new_base::{ PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct A { @@ -328,7 +328,7 @@ impl BuildInto for Soa { //----------- Wks ------------------------------------------------------------ /// Well-known services supported on this domain. -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[derive(IntoBytes, Immutable, ParseBytesByRef)] #[repr(C, packed)] pub struct Wks { /// The address of the host providing these services. diff --git a/src/new_rdata/ipv6.rs b/src/new_rdata/ipv6.rs index 606486d08..fdb2aa674 100644 --- a/src/new_rdata/ipv6.rs +++ b/src/new_rdata/ipv6.rs @@ -8,6 +8,7 @@ use core::{fmt, str::FromStr}; #[cfg(feature = "std")] use std::net::Ipv6Addr; +use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::IntoBytes; use zerocopy_derive::*; @@ -27,11 +28,10 @@ use crate::new_base::build::{ PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, + SplitBytesByRef, )] #[repr(transparent)] pub struct Aaaa { diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 3228608cc..afc4820ae 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -2,6 +2,7 @@ use core::ops::Range; +use domain_macros::ParseBytesByRef; use zerocopy_derive::*; use crate::new_base::{ @@ -172,7 +173,7 @@ impl BuildInto for RecordData<'_, N> { //----------- UnknownRecordData ---------------------------------------------- /// Data for an unknown DNS record type. -#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[derive(Debug, IntoBytes, Immutable, ParseBytesByRef)] #[repr(C)] pub struct UnknownRecordData { /// The unparsed option data. From 499e858929c965fec60b233f068ab2a0a6ee08b8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 15:37:30 +0100 Subject: [PATCH 058/167] Rename '{Split,Parse}From' to '{Split,Parse}Bytes' --- src/new_base/charstr.rs | 15 ++++---- src/new_base/name/label.rs | 14 +++---- src/new_base/name/reversed.rs | 10 ++--- src/new_base/parse/mod.rs | 68 +++++++++++++++++++++++++++------ src/new_base/question.rs | 26 ++++++------- src/new_base/record.rs | 31 +++++++-------- src/new_edns/mod.rs | 34 ++++++++--------- src/new_rdata/basic.rs | 71 ++++++++++++++++++----------------- src/new_rdata/mod.rs | 27 ++++++------- 9 files changed, 172 insertions(+), 124 deletions(-) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index fdd5e5bdf..57f888c27 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -8,7 +8,8 @@ use zerocopy_derive::*; use super::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseBytes, ParseError, ParseFromMessage, SplitBytes, + SplitFromMessage, }, Message, }; @@ -31,7 +32,7 @@ impl<'a> SplitFromMessage<'a> for &'a CharStr { start: usize, ) -> Result<(Self, usize), ParseError> { let bytes = &message.as_bytes()[start..]; - let (this, rest) = Self::split_from(bytes)?; + let (this, rest) = Self::split_bytes(bytes)?; Ok((this, bytes.len() - rest.len())) } } @@ -45,7 +46,7 @@ impl<'a> ParseFromMessage<'a> for &'a CharStr { .as_bytes() .get(range) .ok_or(ParseError) - .and_then(Self::parse_from) + .and_then(Self::parse_bytes) } } @@ -65,8 +66,8 @@ impl BuildIntoMessage for CharStr { //--- Parsing from bytes -impl<'a> SplitFrom<'a> for &'a CharStr { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { +impl<'a> SplitBytes<'a> for &'a CharStr { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (&length, rest) = bytes.split_first().ok_or(ParseError)?; if length as usize > rest.len() { return Err(ParseError); @@ -78,8 +79,8 @@ impl<'a> SplitFrom<'a> for &'a CharStr { } } -impl<'a> ParseFrom<'a> for &'a CharStr { - fn parse_from(bytes: &'a [u8]) -> Result { +impl<'a> ParseBytes<'a> for &'a CharStr { + fn parse_bytes(bytes: &'a [u8]) -> Result { let (&length, rest) = bytes.split_first().ok_or(ParseError)?; if length as usize != rest.len() { return Err(ParseError); diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 087692047..7068e2e15 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -9,7 +9,7 @@ use core::{ use zerocopy_derive::*; -use crate::new_base::parse::{ParseError, ParseFrom, SplitFrom}; +use crate::new_base::parse::{ParseError, ParseBytes, SplitBytes}; //----------- Label ---------------------------------------------------------- @@ -52,8 +52,8 @@ impl Label { //--- Parsing -impl<'a> SplitFrom<'a> for &'a Label { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { +impl<'a> SplitBytes<'a> for &'a Label { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (&size, rest) = bytes.split_first().ok_or(ParseError)?; if size < 64 && rest.len() >= size as usize { let (label, rest) = bytes.split_at(1 + size as usize); @@ -65,9 +65,9 @@ impl<'a> SplitFrom<'a> for &'a Label { } } -impl<'a> ParseFrom<'a> for &'a Label { - fn parse_from(bytes: &'a [u8]) -> Result { - Self::split_from(bytes).and_then(|(this, rest)| { +impl<'a> ParseBytes<'a> for &'a Label { + fn parse_bytes(bytes: &'a [u8]) -> Result { + Self::split_bytes(bytes).and_then(|(this, rest)| { rest.is_empty().then_some(this).ok_or(ParseError) }) } @@ -254,7 +254,7 @@ impl<'a> Iterator for LabelIter<'a> { // SAFETY: 'bytes' is assumed to only contain valid labels. let (head, tail) = - unsafe { <&Label>::split_from(self.bytes).unwrap_unchecked() }; + unsafe { <&Label>::split_bytes(self.bytes).unwrap_unchecked() }; self.bytes = tail; Some(head) } diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 513a72582..ee7b73b9e 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -14,7 +14,7 @@ use zerocopy_derive::*; use crate::new_base::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseError, ParseBytes, ParseFromMessage, SplitBytes, SplitFromMessage, }, Message, }; @@ -385,8 +385,8 @@ impl BuildIntoMessage for RevNameBuf { //--- Parsing from bytes -impl<'a> SplitFrom<'a> for RevNameBuf { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { +impl<'a> SplitBytes<'a> for RevNameBuf { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let mut buffer = Self::empty(); let (pointer, rest) = parse_segment(bytes, &mut buffer)?; @@ -401,8 +401,8 @@ impl<'a> SplitFrom<'a> for RevNameBuf { } } -impl<'a> ParseFrom<'a> for RevNameBuf { - fn parse_from(bytes: &'a [u8]) -> Result { +impl<'a> ParseBytes<'a> for RevNameBuf { + fn parse_bytes(bytes: &'a [u8]) -> Result { let mut buffer = Self::empty(); let (pointer, rest) = parse_segment(bytes, &mut buffer)?; diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index c5faf0440..493542b66 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -72,46 +72,90 @@ impl<'a, T: ?Sized + ParseBytesByRef> ParseFromMessage<'a> for &'a T { //----------- Low-level parsing traits --------------------------------------- /// Parsing from the start of a byte string. -pub trait SplitFrom<'a>: Sized + ParseFrom<'a> { +pub trait SplitBytes<'a>: Sized + ParseBytes<'a> { /// Parse a value of [`Self`] from the start of the byte string. /// /// If parsing is successful, the parsed value and the rest of the string /// are returned. Otherwise, a [`ParseError`] is returned. - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError>; + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError>; } /// Parsing from a byte string. -pub trait ParseFrom<'a>: Sized { +pub trait ParseBytes<'a>: Sized { /// Parse a value of [`Self`] from the given byte string. /// /// If parsing is successful, the parsed value is returned. Otherwise, a /// [`ParseError`] is returned. - fn parse_from(bytes: &'a [u8]) -> Result; + fn parse_bytes(bytes: &'a [u8]) -> Result; } -impl<'a, T: ?Sized + SplitBytesByRef> SplitFrom<'a> for &'a T { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { +impl<'a, T: ?Sized + SplitBytesByRef> SplitBytes<'a> for &'a T { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { T::split_bytes_by_ref(bytes).map_err(|_| ParseError) } } -impl<'a, T: ?Sized + ParseBytesByRef> ParseFrom<'a> for &'a T { - fn parse_from(bytes: &'a [u8]) -> Result { +impl<'a, T: ?Sized + ParseBytesByRef> ParseBytes<'a> for &'a T { + fn parse_bytes(bytes: &'a [u8]) -> Result { T::parse_bytes_by_ref(bytes).map_err(|_| ParseError) } } +impl<'a> SplitBytes<'a> for u8 { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + bytes.split_first().map(|(&f, r)| (f, r)).ok_or(ParseError) + } +} + +impl<'a> ParseBytes<'a> for u8 { + fn parse_bytes(bytes: &'a [u8]) -> Result { + let [result] = bytes else { + return Err(ParseError); + }; + + Ok(*result) + } +} + +impl<'a> SplitBytes<'a> for U16 { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + Self::read_from_prefix(bytes).map_err(Into::into) + } +} + +impl<'a> ParseBytes<'a> for U16 { + fn parse_bytes(bytes: &'a [u8]) -> Result { + Self::read_from_bytes(bytes).map_err(Into::into) + } +} + +impl<'a> SplitBytes<'a> for U32 { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + Self::read_from_prefix(bytes).map_err(Into::into) + } +} + +impl<'a> ParseBytes<'a> for U32 { + fn parse_bytes(bytes: &'a [u8]) -> Result { + Self::read_from_bytes(bytes).map_err(Into::into) + } +} + /// Zero-copy parsing from the start of a byte string. /// +/// This is an extension of [`ParseBytesByRef`] for types which can determine +/// their own length when parsing. It is usually implemented by [`Sized`] +/// types (where the length is just the size of the type), although it can be +/// sometimes implemented by unsized types. +/// /// # Safety /// /// Every implementation of [`SplitBytesByRef`] must satisfy the invariants /// documented on [`split_bytes_by_ref()`]. An incorrect implementation is /// considered to cause undefined behaviour. /// -/// Implementing types should almost always be unaligned, but foregoing this -/// will not cause undefined behaviour (however, it will be very confusing for -/// users). +/// Note that [`ParseBytesByRef`], required by this trait, also has several +/// invariants that need to be considered with care. pub unsafe trait SplitBytesByRef: ParseBytesByRef { /// Interpret a byte string as an instance of [`Self`]. /// @@ -221,7 +265,7 @@ unsafe impl ParseBytesByRef for u8 { return Err(ParseError); }; - return Ok(result); + Ok(result) } fn ptr_with_address(&self, addr: *const ()) -> *const Self { diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 81411aaf0..029f2839f 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -10,7 +10,7 @@ use super::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseError, ParseBytes, ParseFromMessage, SplitBytes, SplitFromMessage, }, Message, }; @@ -98,26 +98,26 @@ where //--- Parsing from bytes -impl<'a, N> SplitFrom<'a> for Question +impl<'a, N> SplitBytes<'a> for Question where - N: SplitFrom<'a>, + N: SplitBytes<'a>, { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (qname, rest) = N::split_from(bytes)?; - let (&qtype, rest) = <&QType>::split_from(rest)?; - let (&qclass, rest) = <&QClass>::split_from(rest)?; + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (qname, rest) = N::split_bytes(bytes)?; + let (&qtype, rest) = <&QType>::split_bytes(rest)?; + let (&qclass, rest) = <&QClass>::split_bytes(rest)?; Ok((Self::new(qname, qtype, qclass), rest)) } } -impl<'a, N> ParseFrom<'a> for Question +impl<'a, N> ParseBytes<'a> for Question where - N: SplitFrom<'a>, + N: SplitBytes<'a>, { - fn parse_from(bytes: &'a [u8]) -> Result { - let (qname, rest) = N::split_from(bytes)?; - let (&qtype, rest) = <&QType>::split_from(rest)?; - let &qclass = <&QClass>::parse_from(rest)?; + fn parse_bytes(bytes: &'a [u8]) -> Result { + let (qname, rest) = N::split_bytes(bytes)?; + let (&qtype, rest) = <&QType>::split_bytes(rest)?; + let &qclass = <&QClass>::parse_bytes(rest)?; Ok(Self::new(qname, qtype, qclass)) } } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 36d4c58dd..0b3bab85b 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -16,7 +16,8 @@ use super::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseBytes, ParseError, ParseFromMessage, SplitBytes, + SplitFromMessage, }, Message, }; @@ -150,16 +151,16 @@ where //--- Parsing from bytes -impl<'a, N, D> SplitFrom<'a> for Record +impl<'a, N, D> SplitBytes<'a> for Record where - N: SplitFrom<'a>, + N: SplitBytes<'a>, D: ParseRecordData<'a>, { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (rname, rest) = N::split_from(bytes)?; - let (&rtype, rest) = <&RType>::split_from(rest)?; - let (&rclass, rest) = <&RClass>::split_from(rest)?; - let (&ttl, rest) = <&TTL>::split_from(rest)?; + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + let (rname, rest) = N::split_bytes(bytes)?; + let (&rtype, rest) = <&RType>::split_bytes(rest)?; + let (&rclass, rest) = <&RClass>::split_bytes(rest)?; + let (&ttl, rest) = <&TTL>::split_bytes(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; @@ -169,16 +170,16 @@ where } } -impl<'a, N, D> ParseFrom<'a> for Record +impl<'a, N, D> ParseBytes<'a> for Record where - N: SplitFrom<'a>, + N: SplitBytes<'a>, D: ParseRecordData<'a>, { - fn parse_from(bytes: &'a [u8]) -> Result { - let (rname, rest) = N::split_from(bytes)?; - let (&rtype, rest) = <&RType>::split_from(rest)?; - let (&rclass, rest) = <&RClass>::split_from(rest)?; - let (&ttl, rest) = <&TTL>::split_from(rest)?; + fn parse_bytes(bytes: &'a [u8]) -> Result { + let (rname, rest) = N::split_bytes(bytes)?; + let (&rtype, rest) = <&RType>::split_bytes(rest)?; + let (&rclass, rest) = <&RClass>::split_bytes(rest)?; + let (&ttl, rest) = <&TTL>::split_bytes(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index f15132d07..8f9c7de65 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -11,7 +11,7 @@ use zerocopy_derive::*; use crate::{ new_base::{ parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, + ParseError, ParseBytes, ParseFromMessage, SplitBytes, SplitFromMessage, }, Message, @@ -48,7 +48,7 @@ impl<'a> SplitFromMessage<'a> for EdnsRecord<'a> { start: usize, ) -> Result<(Self, usize), ParseError> { let bytes = message.as_bytes().get(start..).ok_or(ParseError)?; - let (this, rest) = Self::split_from(bytes)?; + let (this, rest) = Self::split_bytes(bytes)?; Ok((this, message.as_bytes().len() - rest.len())) } } @@ -62,24 +62,24 @@ impl<'a> ParseFromMessage<'a> for EdnsRecord<'a> { .as_bytes() .get(range) .ok_or(ParseError) - .and_then(Self::parse_from) + .and_then(Self::parse_bytes) } } //--- Parsing from bytes -impl<'a> SplitFrom<'a> for EdnsRecord<'a> { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { +impl<'a> SplitBytes<'a> for EdnsRecord<'a> { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { // Strip the record name (root) and the record type. let rest = bytes.strip_prefix(&[0, 0, 41]).ok_or(ParseError)?; - let (&max_udp_payload, rest) = <&U16>::split_from(rest)?; - let (&ext_rcode, rest) = <&u8>::split_from(rest)?; - let (&version, rest) = <&u8>::split_from(rest)?; - let (&flags, rest) = <&EdnsFlags>::split_from(rest)?; + let (&max_udp_payload, rest) = <&U16>::split_bytes(rest)?; + let (&ext_rcode, rest) = <&u8>::split_bytes(rest)?; + let (&version, rest) = <&u8>::split_bytes(rest)?; + let (&flags, rest) = <&EdnsFlags>::split_bytes(rest)?; // Split the record size and data. - let (&size, rest) = <&U16>::split_from(rest)?; + let (&size, rest) = <&U16>::split_bytes(rest)?; let size: usize = size.get().into(); let (options, rest) = Opt::ref_from_prefix_with_elems(rest, size)?; @@ -96,18 +96,18 @@ impl<'a> SplitFrom<'a> for EdnsRecord<'a> { } } -impl<'a> ParseFrom<'a> for EdnsRecord<'a> { - fn parse_from(bytes: &'a [u8]) -> Result { +impl<'a> ParseBytes<'a> for EdnsRecord<'a> { + fn parse_bytes(bytes: &'a [u8]) -> Result { // Strip the record name (root) and the record type. let rest = bytes.strip_prefix(&[0, 0, 41]).ok_or(ParseError)?; - let (&max_udp_payload, rest) = <&U16>::split_from(rest)?; - let (&ext_rcode, rest) = <&u8>::split_from(rest)?; - let (&version, rest) = <&u8>::split_from(rest)?; - let (&flags, rest) = <&EdnsFlags>::split_from(rest)?; + let (&max_udp_payload, rest) = <&U16>::split_bytes(rest)?; + let (&ext_rcode, rest) = <&u8>::split_bytes(rest)?; + let (&version, rest) = <&u8>::split_bytes(rest)?; + let (&flags, rest) = <&EdnsFlags>::split_bytes(rest)?; // Split the record size and data. - let (&size, rest) = <&U16>::split_from(rest)?; + let (&size, rest) = <&U16>::split_bytes(rest)?; let size: usize = size.get().into(); let options = Opt::ref_from_bytes_with_elems(rest, size)?; diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 1b5c0baeb..bfb11b9de 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -20,7 +20,8 @@ use zerocopy_derive::*; use crate::new_base::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseBytes, ParseError, ParseFromMessage, SplitBytes, + SplitFromMessage, }, CharStr, Message, Serial, }; @@ -142,9 +143,9 @@ impl BuildIntoMessage for Ns { //--- Parsing from bytes -impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ns { - fn parse_from(bytes: &'a [u8]) -> Result { - N::parse_from(bytes).map(|name| Self { name }) +impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Ns { + fn parse_bytes(bytes: &'a [u8]) -> Result { + N::parse_bytes(bytes).map(|name| Self { name }) } } @@ -193,9 +194,9 @@ impl BuildIntoMessage for CName { //--- Parsing from bytes -impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for CName { - fn parse_from(bytes: &'a [u8]) -> Result { - N::parse_from(bytes).map(|name| Self { name }) +impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for CName { + fn parse_bytes(bytes: &'a [u8]) -> Result { + N::parse_bytes(bytes).map(|name| Self { name }) } } @@ -285,15 +286,15 @@ impl BuildIntoMessage for Soa { //--- Parsing from bytes -impl<'a, N: SplitFrom<'a>> ParseFrom<'a> for Soa { - fn parse_from(bytes: &'a [u8]) -> Result { - let (mname, rest) = N::split_from(bytes)?; - let (rname, rest) = N::split_from(rest)?; - let (&serial, rest) = <&Serial>::split_from(rest)?; - let (&refresh, rest) = <&U32>::split_from(rest)?; - let (&retry, rest) = <&U32>::split_from(rest)?; - let (&expire, rest) = <&U32>::split_from(rest)?; - let &minimum = <&U32>::parse_from(rest)?; +impl<'a, N: SplitBytes<'a>> ParseBytes<'a> for Soa { + fn parse_bytes(bytes: &'a [u8]) -> Result { + let (mname, rest) = N::split_bytes(bytes)?; + let (rname, rest) = N::split_bytes(rest)?; + let (&serial, rest) = <&Serial>::split_bytes(rest)?; + let (&refresh, rest) = <&U32>::split_bytes(rest)?; + let (&retry, rest) = <&U32>::split_bytes(rest)?; + let (&expire, rest) = <&U32>::split_bytes(rest)?; + let &minimum = <&U32>::parse_bytes(rest)?; Ok(Self { mname, @@ -425,9 +426,9 @@ impl BuildIntoMessage for Ptr { //--- Parsing from bytes -impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Ptr { - fn parse_from(bytes: &'a [u8]) -> Result { - N::parse_from(bytes).map(|name| Self { name }) +impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Ptr { + fn parse_bytes(bytes: &'a [u8]) -> Result { + N::parse_bytes(bytes).map(|name| Self { name }) } } @@ -465,7 +466,7 @@ impl<'a> ParseFromMessage<'a> for HInfo<'a> { .as_bytes() .get(range) .ok_or(ParseError) - .and_then(Self::parse_from) + .and_then(Self::parse_bytes) } } @@ -485,10 +486,10 @@ impl BuildIntoMessage for HInfo<'_> { //--- Parsing from bytes -impl<'a> ParseFrom<'a> for HInfo<'a> { - fn parse_from(bytes: &'a [u8]) -> Result { - let (cpu, rest) = <&CharStr>::split_from(bytes)?; - let os = <&CharStr>::parse_from(rest)?; +impl<'a> ParseBytes<'a> for HInfo<'a> { + fn parse_bytes(bytes: &'a [u8]) -> Result { + let (cpu, rest) = <&CharStr>::split_bytes(bytes)?; + let os = <&CharStr>::parse_bytes(rest)?; Ok(Self { cpu, os }) } } @@ -552,10 +553,10 @@ impl BuildIntoMessage for Mx { //--- Parsing from bytes -impl<'a, N: ParseFrom<'a>> ParseFrom<'a> for Mx { - fn parse_from(bytes: &'a [u8]) -> Result { - let (&preference, rest) = <&U16>::split_from(bytes)?; - let exchange = N::parse_from(rest)?; +impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Mx { + fn parse_bytes(bytes: &'a [u8]) -> Result { + let (&preference, rest) = <&U16>::split_bytes(bytes)?; + let exchange = N::parse_bytes(rest)?; Ok(Self { preference, exchange, @@ -594,11 +595,11 @@ impl Txt { &self, ) -> impl Iterator> + '_ { // NOTE: A TXT record always has at least one 'CharStr' within. - let first = <&CharStr>::split_from(&self.content); + let first = <&CharStr>::split_bytes(&self.content); core::iter::successors(Some(first), |prev| { prev.as_ref() .ok() - .map(|(_elem, rest)| <&CharStr>::split_from(rest)) + .map(|(_elem, rest)| <&CharStr>::split_bytes(rest)) }) .map(|result| result.map(|(elem, _rest)| elem)) } @@ -615,7 +616,7 @@ impl<'a> ParseFromMessage<'a> for &'a Txt { .as_bytes() .get(range) .ok_or(ParseError) - .and_then(Self::parse_from) + .and_then(Self::parse_bytes) } } @@ -632,12 +633,12 @@ impl BuildIntoMessage for Txt { //--- Parsing from bytes -impl<'a> ParseFrom<'a> for &'a Txt { - fn parse_from(bytes: &'a [u8]) -> Result { +impl<'a> ParseBytes<'a> for &'a Txt { + fn parse_bytes(bytes: &'a [u8]) -> Result { // NOTE: The input must contain at least one 'CharStr'. - let (_, mut rest) = <&CharStr>::split_from(bytes)?; + let (_, mut rest) = <&CharStr>::split_bytes(bytes)?; while !rest.is_empty() { - (_, rest) = <&CharStr>::split_from(rest)?; + (_, rest) = <&CharStr>::split_bytes(rest)?; } // SAFETY: 'Txt' is 'repr(transparent)' to '[u8]'. diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index afc4820ae..0cf020988 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -8,7 +8,8 @@ use zerocopy_derive::*; use crate::new_base::{ build::{BuildInto, BuildIntoMessage, Builder, TruncationError}, parse::{ - ParseError, ParseFrom, ParseFromMessage, SplitFrom, SplitFromMessage, + ParseBytes, ParseError, ParseFromMessage, SplitBytes, + SplitFromMessage, }, Message, ParseRecordData, RType, }; @@ -68,7 +69,7 @@ pub enum RecordData<'a, N> { impl<'a, N> ParseRecordData<'a> for RecordData<'a, N> where - N: SplitFrom<'a> + SplitFromMessage<'a>, + N: SplitBytes<'a> + SplitFromMessage<'a>, { fn parse_record_data( message: &'a Message, @@ -110,17 +111,17 @@ where rtype: RType, ) -> Result { match rtype { - RType::A => <&A>::parse_from(bytes).map(Self::A), - RType::NS => Ns::parse_from(bytes).map(Self::Ns), - RType::CNAME => CName::parse_from(bytes).map(Self::CName), - RType::SOA => Soa::parse_from(bytes).map(Self::Soa), - RType::WKS => <&Wks>::parse_from(bytes).map(Self::Wks), - RType::PTR => Ptr::parse_from(bytes).map(Self::Ptr), - RType::HINFO => HInfo::parse_from(bytes).map(Self::HInfo), - RType::MX => Mx::parse_from(bytes).map(Self::Mx), - RType::TXT => <&Txt>::parse_from(bytes).map(Self::Txt), - RType::AAAA => <&Aaaa>::parse_from(bytes).map(Self::Aaaa), - _ => <&UnknownRecordData>::parse_from(bytes) + RType::A => <&A>::parse_bytes(bytes).map(Self::A), + RType::NS => Ns::parse_bytes(bytes).map(Self::Ns), + RType::CNAME => CName::parse_bytes(bytes).map(Self::CName), + RType::SOA => Soa::parse_bytes(bytes).map(Self::Soa), + RType::WKS => <&Wks>::parse_bytes(bytes).map(Self::Wks), + RType::PTR => Ptr::parse_bytes(bytes).map(Self::Ptr), + RType::HINFO => HInfo::parse_bytes(bytes).map(Self::HInfo), + RType::MX => Mx::parse_bytes(bytes).map(Self::Mx), + RType::TXT => <&Txt>::parse_bytes(bytes).map(Self::Txt), + RType::AAAA => <&Aaaa>::parse_bytes(bytes).map(Self::Aaaa), + _ => <&UnknownRecordData>::parse_bytes(bytes) .map(|data| Self::Unknown(rtype, data)), } } From 2d60092169837f60ea32c1e0f12355389433a1b6 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 15:37:51 +0100 Subject: [PATCH 059/167] [macros] Add derives for '{Split,Parse}Bytes' --- macros/src/impls.rs | 77 ++++++------ macros/src/lib.rs | 282 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 315 insertions(+), 44 deletions(-) diff --git a/macros/src/impls.rs b/macros/src/impls.rs index c5af737fb..e72ef91ec 100644 --- a/macros/src/impls.rs +++ b/macros/src/impls.rs @@ -143,7 +143,16 @@ impl ImplSkeleton { /// If the type is concrete, a verifying statement is added for it. /// Otherwise, it is added to the where clause. pub fn require_bound(&mut self, target: Type, bound: TypeParamBound) { - if self.is_concrete(&target) { + let mut visitor = ConcretenessVisitor { + skeleton: self, + is_concrete: true, + }; + + // Concreteness applies to both the type and the bound. + visitor.visit_type(&target); + visitor.visit_type_param_bound(&bound); + + if visitor.is_concrete { // Add a concrete requirement for this bound. self.requirements.stmts.push(parse_quote! { const _: fn() = || { @@ -154,54 +163,16 @@ impl ImplSkeleton { } else { // Add this bound to the `where` clause. let mut bounds = Punctuated::new(); - bounds.push_value(bound); + bounds.push(bound); let pred = WherePredicate::Type(PredicateType { lifetimes: None, bounded_ty: target, colon_token: Default::default(), bounds, }); - self.where_clause.predicates.push_value(pred); + self.where_clause.predicates.push(pred); } } - - /// Whether a type is concrete within this `impl` block. - pub fn is_concrete(&self, target: &Type) -> bool { - struct ConcretenessVisitor<'a> { - /// The `impl` skeleton being added to. - skeleton: &'a ImplSkeleton, - - /// Whether the visited type is concrete. - is_concrete: bool, - } - - impl<'ast> Visit<'ast> for ConcretenessVisitor<'_> { - fn visit_lifetime(&mut self, i: &'ast Lifetime) { - self.is_concrete = self.is_concrete - && self - .skeleton - .lifetimes - .iter() - .all(|l| l.lifetime != *i); - } - - fn visit_ident(&mut self, i: &'ast Ident) { - self.is_concrete = self.is_concrete - && self.skeleton.types.iter().all(|t| t.ident != *i); - self.is_concrete = self.is_concrete - && self.skeleton.consts.iter().all(|c| c.ident != *i); - } - } - - let mut visitor = ConcretenessVisitor { - skeleton: self, - is_concrete: true, - }; - - visitor.visit_type(target); - - visitor.is_concrete - } } impl ToTokens for ImplSkeleton { @@ -235,3 +206,27 @@ impl ToTokens for ImplSkeleton { } } } + +//----------- ConcretenessVisitor -------------------------------------------- + +struct ConcretenessVisitor<'a> { + /// The `impl` skeleton being added to. + skeleton: &'a ImplSkeleton, + + /// Whether the visited type is concrete. + is_concrete: bool, +} + +impl<'ast> Visit<'ast> for ConcretenessVisitor<'_> { + fn visit_lifetime(&mut self, i: &'ast Lifetime) { + self.is_concrete = self.is_concrete + && self.skeleton.lifetimes.iter().all(|l| l.lifetime != *i); + } + + fn visit_ident(&mut self, i: &'ast Ident) { + self.is_concrete = self.is_concrete + && self.skeleton.types.iter().all(|t| t.ident != *i); + self.is_concrete = self.is_concrete + && self.skeleton.consts.iter().all(|c| c.ident != *i); + } +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8cb26183f..33eb6eef5 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -3,8 +3,8 @@ //! [`domain`]: https://docs.rs/domain use proc_macro as pm; -use proc_macro2::TokenStream; -use quote::{quote, ToTokens}; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, ToTokens}; use spanned::Spanned; use syn::*; @@ -14,6 +14,275 @@ use impls::ImplSkeleton; mod repr; use repr::Repr; +//----------- SplitBytes ----------------------------------------------------- + +#[proc_macro_derive(SplitBytes)] +pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let data = match &input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'SplitBytes' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'SplitBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let bound = + parse_quote!(::domain::new_base::parse::SplitBytes<'bytes>); + let mut skeleton = ImplSkeleton::new(&input, false, bound); + + // Pick a non-conflicting name for the parsing lifetime. + let lifetime = [format_ident!("bytes")] + .into_iter() + .chain((0u32..).map(|i| format_ident!("bytes_{}", i))) + .find(|id| { + skeleton.lifetimes.iter().all(|l| l.lifetime.ident != *id) + }) + .map(|ident| Lifetime { + apostrophe: Span::call_site(), + ident, + }) + .unwrap(); + + // Add the parsing lifetime to the 'impl'. + if skeleton.lifetimes.len() > 0 { + let lifetimes = skeleton.lifetimes.iter(); + let param = parse_quote! { + #lifetime: #(#lifetimes)+* + }; + skeleton.lifetimes.push(param); + } else { + skeleton.lifetimes.push(parse_quote! { #lifetime }) + } + + // Establish bounds on the fields. + for field in data.fields.iter() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + ); + } + + // Construct a 'Self' expression. + let self_expr = match &data.fields { + Fields::Named(_) => { + let names = data.fields.members(); + let exprs = + names.clone().map(|n| format_ident!("field_{}", n)); + quote! { + Self { + #(#names: #exprs,)* + } + } + } + + Fields::Unnamed(_) => { + let exprs = data + .fields + .members() + .map(|n| format_ident!("field_{}", n)); + quote! { + Self(#(#exprs,)*) + } + } + + Fields::Unit => quote! { Self }, + }; + + // Define 'parse_bytes()'. + let names = + data.fields.members().map(|n| format_ident!("field_{}", n)); + let tys = data.fields.iter().map(|f| &f.ty); + skeleton.contents.stmts.push(parse_quote! { + fn split_bytes( + bytes: & #lifetime [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (Self, & #lifetime [::domain::__core::primitive::u8]), + ::domain::new_base::parse::ParseError, + > { + #(let (#names, bytes) = + <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> + ::split_bytes(bytes)?;)* + Ok((#self_expr, bytes)) + } + }); + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +//----------- ParseBytes ----------------------------------------------------- + +#[proc_macro_derive(ParseBytes)] +pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let data = match &input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'ParseBytes' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'ParseBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + // Split up the last field from the rest. + let mut fields = data.fields.iter(); + let Some(last) = fields.next_back() else { + // This type has no fields. Return a simple implementation. + assert!(input.generics.params.is_empty()); + let where_clause = input.generics.where_clause; + let name = input.ident; + + // This will tokenize to '{}', '()', or ''. + let fields = data.fields.to_token_stream(); + + return Ok(quote! { + impl <'bytes> + ::domain::new_base::parse::ParseBytes<'bytes> + for #name + #where_clause { + fn parse_bytes( + bytes: &'bytes [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + Self, + ::domain::new_base::parse::ParseError, + > { + if bytes.is_empty() { + Ok(Self #fields) + } else { + Err() + } + } + } + }); + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let bound = + parse_quote!(::domain::new_base::parse::ParseBytes<'bytes>); + let mut skeleton = ImplSkeleton::new(&input, false, bound); + + // Pick a non-conflicting name for the parsing lifetime. + let lifetime = [format_ident!("bytes")] + .into_iter() + .chain((0u32..).map(|i| format_ident!("bytes_{}", i))) + .find(|id| { + skeleton.lifetimes.iter().all(|l| l.lifetime.ident != *id) + }) + .map(|ident| Lifetime { + apostrophe: Span::call_site(), + ident, + }) + .unwrap(); + + // Add the parsing lifetime to the 'impl'. + if skeleton.lifetimes.len() > 0 { + let lifetimes = skeleton.lifetimes.iter(); + let param = parse_quote! { + #lifetime: #(#lifetimes)+* + }; + skeleton.lifetimes.push(param); + } else { + skeleton.lifetimes.push(parse_quote! { #lifetime }) + } + + // Establish bounds on the fields. + for field in fields.clone() { + // This field should implement 'SplitBytes'. + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + ); + } + // The last field should implement 'ParseBytes'. + skeleton.require_bound( + last.ty.clone(), + parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), + ); + + // Construct a 'Self' expression. + let self_expr = match &data.fields { + Fields::Named(_) => { + let names = data.fields.members(); + let exprs = + names.clone().map(|n| format_ident!("field_{}", n)); + quote! { + Self { + #(#names: #exprs,)* + } + } + } + + Fields::Unnamed(_) => { + let exprs = data + .fields + .members() + .map(|n| format_ident!("field_{}", n)); + quote! { + Self(#(#exprs,)*) + } + } + + Fields::Unit => unreachable!(), + }; + + // Define 'parse_bytes()'. + let names = data + .fields + .members() + .take(fields.len()) + .map(|n| format_ident!("field_{}", n)); + let tys = fields.clone().map(|f| &f.ty); + let last_ty = &last.ty; + let last_name = + format_ident!("field_{}", data.fields.members().last().unwrap()); + skeleton.contents.stmts.push(parse_quote! { + fn parse_bytes( + bytes: & #lifetime [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + Self, + ::domain::new_base::parse::ParseError, + > { + #(let (#names, bytes) = + <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> + ::split_bytes(bytes)?;)* + let #last_name = + <#last_ty as ::domain::new_base::parse::ParseBytes<#lifetime>> + ::parse_bytes(bytes)?; + Ok(#self_expr) + } + }); + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + //----------- SplitBytesByRef ------------------------------------------------ #[proc_macro_derive(SplitBytesByRef)] @@ -163,7 +432,14 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { &Self, ::domain::new_base::parse::ParseError, > { - Ok(unsafe { &*bytes.as_ptr().cast::() }) + if bytes.is_empty() { + // SAFETY: 'Self' is a 'struct' with no fields, + // and so has size 0 and alignment 1. It can be + // constructed at any address. + Ok(unsafe { &*bytes.as_ptr().cast::() }) + } else { + Err(::domain::new_base::parse::ParseError) + } } fn ptr_with_address( From 7e0ef89fefca1a9fd98a3a5bb286cc34f7c228d2 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 15:41:50 +0100 Subject: [PATCH 060/167] [macros] Factor out 'new_lifetime()' --- macros/src/impls.rs | 17 +++++++++++++++-- macros/src/lib.rs | 30 +++--------------------------- 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/macros/src/impls.rs b/macros/src/impls.rs index e72ef91ec..0d97309e6 100644 --- a/macros/src/impls.rs +++ b/macros/src/impls.rs @@ -1,7 +1,7 @@ //! Helpers for generating `impl` blocks. -use proc_macro2::TokenStream; -use quote::{quote, ToTokens}; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, ToTokens}; use syn::{punctuated::Punctuated, visit::Visit, *}; //----------- ImplSkeleton --------------------------------------------------- @@ -173,6 +173,19 @@ impl ImplSkeleton { self.where_clause.predicates.push(pred); } } + + /// Generate a unique lifetime with the given prefix. + pub fn new_lifetime(&self, prefix: &str) -> Lifetime { + [format_ident!("{}", prefix)] + .into_iter() + .chain((0u32..).map(|i| format_ident!("{}_{}", prefix, i))) + .find(|id| self.lifetimes.iter().all(|l| l.lifetime.ident != *id)) + .map(|ident| Lifetime { + apostrophe: Span::call_site(), + ident, + }) + .unwrap() + } } impl ToTokens for ImplSkeleton { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 33eb6eef5..046b439ca 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -3,7 +3,7 @@ //! [`domain`]: https://docs.rs/domain use proc_macro as pm; -use proc_macro2::{Span, TokenStream}; +use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; use spanned::Spanned; use syn::*; @@ -40,20 +40,8 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { parse_quote!(::domain::new_base::parse::SplitBytes<'bytes>); let mut skeleton = ImplSkeleton::new(&input, false, bound); - // Pick a non-conflicting name for the parsing lifetime. - let lifetime = [format_ident!("bytes")] - .into_iter() - .chain((0u32..).map(|i| format_ident!("bytes_{}", i))) - .find(|id| { - skeleton.lifetimes.iter().all(|l| l.lifetime.ident != *id) - }) - .map(|ident| Lifetime { - apostrophe: Span::call_site(), - ident, - }) - .unwrap(); - // Add the parsing lifetime to the 'impl'. + let lifetime = skeleton.new_lifetime("bytes"); if skeleton.lifetimes.len() > 0 { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { @@ -183,20 +171,8 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { parse_quote!(::domain::new_base::parse::ParseBytes<'bytes>); let mut skeleton = ImplSkeleton::new(&input, false, bound); - // Pick a non-conflicting name for the parsing lifetime. - let lifetime = [format_ident!("bytes")] - .into_iter() - .chain((0u32..).map(|i| format_ident!("bytes_{}", i))) - .find(|id| { - skeleton.lifetimes.iter().all(|l| l.lifetime.ident != *id) - }) - .map(|ident| Lifetime { - apostrophe: Span::call_site(), - ident, - }) - .unwrap(); - // Add the parsing lifetime to the 'impl'. + let lifetime = skeleton.new_lifetime("bytes"); if skeleton.lifetimes.len() > 0 { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { From b0f6b679e034ebcdd1d02ce8c316af3d63ee104b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 16:06:29 +0100 Subject: [PATCH 061/167] Use parsing trait derives across the 'new_*' codebase --- src/new_base/message.rs | 18 +++++- src/new_base/question.rs | 39 +++--------- src/new_base/record.rs | 30 +++++---- src/new_base/serial.rs | 5 +- src/new_edns/mod.rs | 24 +++++-- src/new_rdata/basic.rs | 132 ++++++++++++++++----------------------- src/new_rdata/edns.rs | 8 +-- src/new_rdata/ipv6.rs | 5 +- 8 files changed, 127 insertions(+), 134 deletions(-) diff --git a/src/new_base/message.rs b/src/new_base/message.rs index e60ae76ff..5d635a9e3 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -5,10 +5,14 @@ use core::fmt; use zerocopy::network_endian::U16; use zerocopy_derive::*; +use domain_macros::*; + //----------- Message -------------------------------------------------------- /// A DNS message. -#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)] +#[derive( + FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, ParseBytesByRef, +)] #[repr(C, packed)] pub struct Message { /// The message header. @@ -31,6 +35,10 @@ pub struct Message { KnownLayout, Immutable, Unaligned, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, )] #[repr(C)] pub struct Header { @@ -71,6 +79,10 @@ impl fmt::Display for Header { KnownLayout, Immutable, Unaligned, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, )] #[repr(transparent)] pub struct HeaderFlags { @@ -232,6 +244,10 @@ impl fmt::Display for HeaderFlags { KnownLayout, Immutable, Unaligned, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, )] #[repr(C)] pub struct SectionCounts { diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 029f2839f..8a9ad771f 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -2,23 +2,22 @@ use core::ops::Range; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{network_endian::U16, IntoBytes}; use zerocopy_derive::*; +use domain_macros::*; + use super::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, - parse::{ - ParseError, ParseBytes, ParseFromMessage, SplitBytes, SplitFromMessage, - }, + parse::{ParseError, ParseFromMessage, SplitFromMessage}, Message, }; //----------- Question ------------------------------------------------------- /// A DNS question. -#[derive(Clone)] +#[derive(Clone, ParseBytes, SplitBytes)] pub struct Question { /// The domain name being requested. pub qname: N, @@ -96,32 +95,6 @@ where } } -//--- Parsing from bytes - -impl<'a, N> SplitBytes<'a> for Question -where - N: SplitBytes<'a>, -{ - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - let (qname, rest) = N::split_bytes(bytes)?; - let (&qtype, rest) = <&QType>::split_bytes(rest)?; - let (&qclass, rest) = <&QClass>::split_bytes(rest)?; - Ok((Self::new(qname, qtype, qclass), rest)) - } -} - -impl<'a, N> ParseBytes<'a> for Question -where - N: SplitBytes<'a>, -{ - fn parse_bytes(bytes: &'a [u8]) -> Result { - let (qname, rest) = N::split_bytes(bytes)?; - let (&qtype, rest) = <&QType>::split_bytes(rest)?; - let &qclass = <&QClass>::parse_bytes(rest)?; - Ok(Self::new(qname, qtype, qclass)) - } -} - //--- Building into byte strings impl BuildInto for Question @@ -153,7 +126,9 @@ where Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] @@ -176,7 +151,9 @@ pub struct QType { Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 0b3bab85b..c02780d53 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -5,19 +5,20 @@ use core::{ ops::{Deref, Range}, }; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{ network_endian::{U16, U32}, - FromBytes, IntoBytes, SizeError, + FromBytes, IntoBytes, }; use zerocopy_derive::*; +use domain_macros::*; + use super::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ - ParseBytes, ParseError, ParseFromMessage, SplitBytes, - SplitFromMessage, + ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, + SplitBytes, SplitFromMessage, }, Message, }; @@ -104,8 +105,7 @@ where range: Range, ) -> Result { let message = &message.as_bytes()[..range.end]; - let message = Message::ref_from_bytes(message) - .map_err(SizeError::from) + let message = Message::parse_bytes_by_ref(message) .expect("The input range ends past the message header"); let (this, rest) = Self::split_from_message(message, range.start)?; @@ -158,9 +158,9 @@ where { fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (rname, rest) = N::split_bytes(bytes)?; - let (&rtype, rest) = <&RType>::split_bytes(rest)?; - let (&rclass, rest) = <&RClass>::split_bytes(rest)?; - let (&ttl, rest) = <&TTL>::split_bytes(rest)?; + let (rtype, rest) = RType::split_bytes(rest)?; + let (rclass, rest) = RClass::split_bytes(rest)?; + let (ttl, rest) = TTL::split_bytes(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; @@ -177,9 +177,9 @@ where { fn parse_bytes(bytes: &'a [u8]) -> Result { let (rname, rest) = N::split_bytes(bytes)?; - let (&rtype, rest) = <&RType>::split_bytes(rest)?; - let (&rclass, rest) = <&RClass>::split_bytes(rest)?; - let (&ttl, rest) = <&TTL>::split_bytes(rest)?; + let (rtype, rest) = RType::split_bytes(rest)?; + let (rclass, rest) = RClass::split_bytes(rest)?; + let (ttl, rest) = TTL::split_bytes(rest)?; let (size, rest) = U16::read_from_prefix(rest)?; let size: usize = size.get().into(); let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; @@ -232,7 +232,9 @@ where Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] @@ -289,7 +291,9 @@ impl RType { Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] @@ -312,7 +316,9 @@ pub struct RClass { Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs index f351e1a46..2fe5e8f7c 100644 --- a/src/new_base/serial.rs +++ b/src/new_base/serial.rs @@ -8,10 +8,11 @@ use core::{ ops::{Add, AddAssign}, }; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::network_endian::U32; use zerocopy_derive::*; +use domain_macros::*; + //----------- Serial --------------------------------------------------------- /// A serial number. @@ -24,7 +25,9 @@ use zerocopy_derive::*; Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 8f9c7de65..d5a1d366f 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -4,15 +4,16 @@ use core::{fmt, ops::Range}; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; -use zerocopy::{network_endian::U16, FromBytes, IntoBytes}; +use zerocopy::{network_endian::U16, IntoBytes}; use zerocopy_derive::*; +use domain_macros::*; + use crate::{ new_base::{ parse::{ - ParseError, ParseBytes, ParseFromMessage, SplitBytes, - SplitFromMessage, + ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, + SplitBytes, SplitFromMessage, }, Message, }, @@ -81,7 +82,11 @@ impl<'a> SplitBytes<'a> for EdnsRecord<'a> { // Split the record size and data. let (&size, rest) = <&U16>::split_bytes(rest)?; let size: usize = size.get().into(); - let (options, rest) = Opt::ref_from_prefix_with_elems(rest, size)?; + if rest.len() < size { + return Err(ParseError); + } + let (options, rest) = rest.split_at(size); + let options = Opt::parse_bytes_by_ref(options)?; Ok(( Self { @@ -109,7 +114,10 @@ impl<'a> ParseBytes<'a> for EdnsRecord<'a> { // Split the record size and data. let (&size, rest) = <&U16>::split_bytes(rest)?; let size: usize = size.get().into(); - let options = Opt::ref_from_bytes_with_elems(rest, size)?; + if rest.len() != size { + return Err(ParseError); + } + let options = Opt::parse_bytes_by_ref(rest)?; Ok(Self { max_udp_payload, @@ -131,7 +139,9 @@ impl<'a> ParseBytes<'a> for EdnsRecord<'a> { Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] @@ -209,7 +219,9 @@ pub enum EdnsOption<'b> { Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index bfb11b9de..53251ee7f 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -10,13 +10,14 @@ use core::str::FromStr; #[cfg(feature = "std")] use std::net::Ipv4Addr; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::{ network_endian::{U16, U32}, IntoBytes, }; use zerocopy_derive::*; +use domain_macros::*; + use crate::new_base::{ build::{self, BuildInto, BuildIntoMessage, TruncationError}, parse::{ @@ -40,7 +41,9 @@ use crate::new_base::{ Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] @@ -112,7 +115,18 @@ impl BuildInto for A { //----------- Ns ------------------------------------------------------------- /// The authoritative name server for this domain. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + ParseBytes, + SplitBytes, +)] #[repr(transparent)] pub struct Ns { /// The name of the authoritative server. @@ -141,14 +155,6 @@ impl BuildIntoMessage for Ns { } } -//--- Parsing from bytes - -impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Ns { - fn parse_bytes(bytes: &'a [u8]) -> Result { - N::parse_bytes(bytes).map(|name| Self { name }) - } -} - //--- Building into bytes impl BuildInto for Ns { @@ -163,7 +169,18 @@ impl BuildInto for Ns { //----------- Cname ---------------------------------------------------------- /// The canonical name for this domain. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + ParseBytes, + SplitBytes, +)] #[repr(transparent)] pub struct CName { /// The canonical name. @@ -192,14 +209,6 @@ impl BuildIntoMessage for CName { } } -//--- Parsing from bytes - -impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for CName { - fn parse_bytes(bytes: &'a [u8]) -> Result { - N::parse_bytes(bytes).map(|name| Self { name }) - } -} - //--- Building into bytes impl BuildInto for CName { @@ -214,7 +223,7 @@ impl BuildInto for CName { //----------- Soa ------------------------------------------------------------ /// The start of a zone of authority. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, ParseBytes, SplitBytes)] pub struct Soa { /// The name server which provided this zone. pub mname: N, @@ -284,30 +293,6 @@ impl BuildIntoMessage for Soa { } } -//--- Parsing from bytes - -impl<'a, N: SplitBytes<'a>> ParseBytes<'a> for Soa { - fn parse_bytes(bytes: &'a [u8]) -> Result { - let (mname, rest) = N::split_bytes(bytes)?; - let (rname, rest) = N::split_bytes(rest)?; - let (&serial, rest) = <&Serial>::split_bytes(rest)?; - let (&refresh, rest) = <&U32>::split_bytes(rest)?; - let (&retry, rest) = <&U32>::split_bytes(rest)?; - let (&expire, rest) = <&U32>::split_bytes(rest)?; - let &minimum = <&U32>::parse_bytes(rest)?; - - Ok(Self { - mname, - rname, - serial, - refresh, - retry, - expire, - minimum, - }) - } -} - //--- Building into byte strings impl BuildInto for Soa { @@ -395,7 +380,18 @@ impl BuildInto for Wks { //----------- Ptr ------------------------------------------------------------ /// A pointer to another domain name. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + ParseBytes, + SplitBytes, +)] #[repr(transparent)] pub struct Ptr { /// The referenced domain name. @@ -424,14 +420,6 @@ impl BuildIntoMessage for Ptr { } } -//--- Parsing from bytes - -impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Ptr { - fn parse_bytes(bytes: &'a [u8]) -> Result { - N::parse_bytes(bytes).map(|name| Self { name }) - } -} - //--- Building into bytes impl BuildInto for Ptr { @@ -446,7 +434,7 @@ impl BuildInto for Ptr { //----------- HInfo ---------------------------------------------------------- /// Information about the host computer. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, ParseBytes, SplitBytes)] pub struct HInfo<'a> { /// The CPU type. pub cpu: &'a CharStr, @@ -484,16 +472,6 @@ impl BuildIntoMessage for HInfo<'_> { } } -//--- Parsing from bytes - -impl<'a> ParseBytes<'a> for HInfo<'a> { - fn parse_bytes(bytes: &'a [u8]) -> Result { - let (cpu, rest) = <&CharStr>::split_bytes(bytes)?; - let os = <&CharStr>::parse_bytes(rest)?; - Ok(Self { cpu, os }) - } -} - //--- Building into bytes impl BuildInto for HInfo<'_> { @@ -510,7 +488,18 @@ impl BuildInto for HInfo<'_> { //----------- Mx ------------------------------------------------------------- /// A host that can exchange mail for this domain. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + ParseBytes, + SplitBytes, +)] #[repr(C)] pub struct Mx { /// The preference for this host over others. @@ -551,19 +540,6 @@ impl BuildIntoMessage for Mx { } } -//--- Parsing from bytes - -impl<'a, N: ParseBytes<'a>> ParseBytes<'a> for Mx { - fn parse_bytes(bytes: &'a [u8]) -> Result { - let (&preference, rest) = <&U16>::split_bytes(bytes)?; - let exchange = N::parse_bytes(rest)?; - Ok(Self { - preference, - exchange, - }) - } -} - //--- Building into byte strings impl BuildInto for Mx { diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 89e146062..4f84ba837 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -4,6 +4,8 @@ use zerocopy_derive::*; +use domain_macros::*; + use crate::new_base::build::{ self, BuildInto, BuildIntoMessage, TruncationError, }; @@ -17,13 +19,11 @@ use crate::new_base::build::{ PartialOrd, Ord, Hash, - FromBytes, IntoBytes, - KnownLayout, Immutable, - Unaligned, + ParseBytesByRef, )] -#[repr(C)] // 'derive(KnownLayout)' doesn't work with 'repr(transparent)'. +#[repr(transparent)] pub struct Opt { /// The raw serialized options. contents: [u8], diff --git a/src/new_rdata/ipv6.rs b/src/new_rdata/ipv6.rs index fdb2aa674..77df07cc5 100644 --- a/src/new_rdata/ipv6.rs +++ b/src/new_rdata/ipv6.rs @@ -8,10 +8,11 @@ use core::{fmt, str::FromStr}; #[cfg(feature = "std")] use std::net::Ipv6Addr; -use domain_macros::{ParseBytesByRef, SplitBytesByRef}; use zerocopy::IntoBytes; use zerocopy_derive::*; +use domain_macros::*; + use crate::new_base::build::{ self, BuildInto, BuildIntoMessage, TruncationError, }; @@ -30,7 +31,9 @@ use crate::new_base::build::{ Hash, IntoBytes, Immutable, + ParseBytes, ParseBytesByRef, + SplitBytes, SplitBytesByRef, )] #[repr(transparent)] From 4a5d343e7f38b9380ea18f40734fe8ee05b57f9e Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 17:54:07 +0100 Subject: [PATCH 062/167] Impl and use derives for 'AsBytes' and 'BuildBytes' --- macros/src/impls.rs | 13 ++- macros/src/lib.rs | 134 +++++++++++++++++++++++++++--- src/new_base/build/builder.rs | 4 +- src/new_base/build/mod.rs | 116 ++++++++++++++++++++++---- src/new_base/charstr.rs | 10 +-- src/new_base/message.rs | 8 +- src/new_base/name/label.rs | 6 +- src/new_base/name/reversed.rs | 19 +++-- src/new_base/question.rs | 32 ++------ src/new_base/record.rs | 46 ++++------- src/new_base/serial.rs | 5 +- src/new_rdata/basic.rs | 148 ++++++---------------------------- src/new_rdata/edns.rs | 26 +----- src/new_rdata/ipv6.rs | 20 +---- src/new_rdata/mod.rs | 33 ++++---- 15 files changed, 332 insertions(+), 288 deletions(-) diff --git a/macros/src/impls.rs b/macros/src/impls.rs index 0d97309e6..5e3b884a0 100644 --- a/macros/src/impls.rs +++ b/macros/src/impls.rs @@ -21,7 +21,7 @@ pub struct ImplSkeleton { pub unsafety: Option, /// The trait being implemented. - pub bound: Path, + pub bound: Option, /// The type being implemented on. pub subject: Path, @@ -38,7 +38,7 @@ pub struct ImplSkeleton { impl ImplSkeleton { /// Construct an [`ImplSkeleton`] for a [`DeriveInput`]. - pub fn new(input: &DeriveInput, unsafety: bool, bound: Path) -> Self { + pub fn new(input: &DeriveInput, unsafety: bool) -> Self { let mut lifetimes = Vec::new(); let mut types = Vec::new(); let mut consts = Vec::new(); @@ -130,7 +130,7 @@ impl ImplSkeleton { types, consts, unsafety, - bound, + bound: None, subject, where_clause, contents, @@ -202,10 +202,15 @@ impl ToTokens for ImplSkeleton { requirements, } = self; + let target = match bound { + Some(bound) => quote!(#bound for #subject), + None => quote!(#subject), + }; + quote! { #unsafety impl<#(#lifetimes,)* #(#types,)* #(#consts,)*> - #bound for #subject + #target #where_clause #contents } diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 046b439ca..2885844af 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -36,12 +36,13 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { }; // Construct an 'ImplSkeleton' so that we can add trait bounds. - let bound = - parse_quote!(::domain::new_base::parse::SplitBytes<'bytes>); - let mut skeleton = ImplSkeleton::new(&input, false, bound); + let mut skeleton = ImplSkeleton::new(&input, false); // Add the parsing lifetime to the 'impl'. let lifetime = skeleton.new_lifetime("bytes"); + skeleton.bound = Some( + parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + ); if skeleton.lifetimes.len() > 0 { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { @@ -167,12 +168,13 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { }; // Construct an 'ImplSkeleton' so that we can add trait bounds. - let bound = - parse_quote!(::domain::new_base::parse::ParseBytes<'bytes>); - let mut skeleton = ImplSkeleton::new(&input, false, bound); + let mut skeleton = ImplSkeleton::new(&input, false); // Add the parsing lifetime to the 'impl'. let lifetime = skeleton.new_lifetime("bytes"); + skeleton.bound = Some( + parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), + ); if skeleton.lifetimes.len() > 0 { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { @@ -311,8 +313,9 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { }; // Construct an 'ImplSkeleton' so that we can add trait bounds. - let bound = parse_quote!(::domain::new_base::parse::SplitBytesByRef); - let mut skeleton = ImplSkeleton::new(&input, true, bound); + let mut skeleton = ImplSkeleton::new(&input, true); + skeleton.bound = + Some(parse_quote!(::domain::new_base::parse::SplitBytesByRef)); // Establish bounds on the fields. for field in data.fields.iter() { @@ -429,8 +432,9 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { }; // Construct an 'ImplSkeleton' so that we can add trait bounds. - let bound = parse_quote!(::domain::new_base::parse::ParseBytesByRef); - let mut skeleton = ImplSkeleton::new(&input, true, bound); + let mut skeleton = ImplSkeleton::new(&input, true); + skeleton.bound = + Some(parse_quote!(::domain::new_base::parse::ParseBytesByRef)); // Establish bounds on the fields. for field in fields.clone() { @@ -505,3 +509,113 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { .unwrap_or_else(syn::Error::into_compile_error) .into() } + +//----------- BuildBytes ----------------------------------------------------- + +#[proc_macro_derive(BuildBytes)] +pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let data = match &input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'BuildBytes' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'BuildBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let mut skeleton = ImplSkeleton::new(&input, false); + skeleton.bound = + Some(parse_quote!(::domain::new_base::build::BuildBytes)); + + // Get a lifetime for the input buffer. + let lifetime = skeleton.new_lifetime("bytes"); + + // Establish bounds on the fields. + for field in data.fields.iter() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::build::BuildBytes), + ); + } + + // Define 'build_bytes()'. + let names = data.fields.members(); + let tys = data.fields.iter().map(|f| &f.ty); + skeleton.contents.stmts.push(parse_quote! { + fn build_bytes<#lifetime>( + &self, + mut bytes: & #lifetime mut [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + & #lifetime mut [::domain::__core::primitive::u8], + ::domain::new_base::build::TruncationError, + > { + #(bytes = <#tys as ::domain::new_base::build::BuildBytes> + ::build_bytes(&self.#names, bytes)?;)* + Ok(bytes) + } + }); + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +//----------- AsBytes -------------------------------------------------------- + +#[proc_macro_derive(AsBytes)] +pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: DeriveInput) -> Result { + let data = match &input.data { + Data::Struct(data) => data, + Data::Enum(data) => { + return Err(Error::new_spanned( + data.enum_token, + "'AsBytes' can only be 'derive'd for 'struct's", + )); + } + Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'AsBytes' can only be 'derive'd for 'struct's", + )); + } + }; + + let _ = Repr::determine(&input.attrs, "AsBytes")?; + + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let mut skeleton = ImplSkeleton::new(&input, true); + skeleton.bound = + Some(parse_quote!(::domain::new_base::build::AsBytes)); + + // Establish bounds on the fields. + for field in data.fields.iter() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::build::AsBytes), + ); + } + + // The default implementation of 'as_bytes()' works perfectly. + + Ok(skeleton.into_token_stream().into()) + } + + let input = syn::parse_macro_input!(input as DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 75a9cfc69..9245b9011 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -10,7 +10,7 @@ use zerocopy::{FromBytes, IntoBytes, SizeError}; use crate::new_base::{name::RevName, Header, Message}; -use super::{BuildInto, TruncationError}; +use super::{BuildBytes, TruncationError}; //----------- Builder -------------------------------------------------------- @@ -303,7 +303,7 @@ impl Builder<'_> { name: &RevName, ) -> Result<(), TruncationError> { // TODO: Perform name compression. - name.build_into(self.uninitialized())?; + name.build_bytes(self.uninitialized())?; self.mark_appended(name.len()); Ok(()) } diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 108cc76f0..56670e922 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -2,6 +2,8 @@ use core::fmt; +use zerocopy::network_endian::{U16, U32}; + mod builder; pub use builder::{Builder, BuilderContext}; @@ -42,36 +44,38 @@ impl BuildIntoMessage for [u8] { //----------- Low-level building traits -------------------------------------- -/// Building into a byte string. -pub trait BuildInto { - /// Append this value to the byte string. +/// Serializing into a byte string. +pub trait BuildBytes { + /// Serialize into a byte string. /// - /// If the byte string is long enough to fit the message, the remaining - /// (unfilled) part of the byte string is returned. Otherwise, a - /// [`TruncationError`] is returned. - fn build_into<'b>( + /// `self` is serialized into a byte string and written to the given + /// buffer. If the buffer is large enough, the whole object is written + /// and the remaining (unmodified) part of the buffer is returned. + /// + /// if the buffer is too small, a [`TruncationError`] is returned (and + /// parts of the buffer may be modified). + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError>; } -impl BuildInto for &T { - fn build_into<'b>( +impl BuildBytes for &T { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { - (**self).build_into(bytes) + T::build_bytes(*self, bytes) } } -impl BuildInto for [u8] { - fn build_into<'b>( +impl BuildBytes for u8 { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { - if self.len() <= bytes.len() { - let (bytes, rest) = bytes.split_at_mut(self.len()); - bytes.copy_from_slice(self); + if let Some((elem, rest)) = bytes.split_first_mut() { + *elem = *self; Ok(rest) } else { Err(TruncationError) @@ -79,6 +83,88 @@ impl BuildInto for [u8] { } } +impl BuildBytes for U16 { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_bytes(bytes) + } +} + +impl BuildBytes for U32 { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_bytes(bytes) + } +} + +impl BuildBytes for [T] { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + for elem in self { + bytes = elem.build_bytes(bytes)?; + } + Ok(bytes) + } +} + +impl BuildBytes for [T; N] { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + for elem in self { + bytes = elem.build_bytes(bytes)?; + } + Ok(bytes) + } +} + +/// Interpreting a value as a byte string. +/// +/// # Safety +/// +/// A type `T` can soundly implement [`AsBytes`] if and only if: +/// +/// - It has no padding bytes. +/// - It has no interior mutability. +pub unsafe trait AsBytes { + /// Interpret this value as a sequence of bytes. + /// + /// ## Invariants + /// + /// For the statement `let bytes = this.as_bytes();`, + /// + /// - `bytes.as_ptr() as usize == this as *const _ as usize`. + /// - `bytes.len() == core::mem::size_of_val(this)`. + /// + /// The default implementation automatically satisfies these invariants. + fn as_bytes(&self) -> &[u8] { + // SAFETY: + // - 'Self' has no padding bytes and no interior mutability. + // - Its size in memory is exactly 'size_of_val(self)'. + unsafe { + core::slice::from_raw_parts( + self as *const Self as *const u8, + core::mem::size_of_val(self), + ) + } + } +} + +unsafe impl AsBytes for u8 {} + +unsafe impl AsBytes for [T] {} +unsafe impl AsBytes for [T; N] {} + +unsafe impl AsBytes for U16 {} +unsafe impl AsBytes for U32 {} + //----------- TruncationError ------------------------------------------------ /// A DNS message did not fit in a buffer. diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 57f888c27..2a82e95fa 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -3,10 +3,9 @@ use core::{fmt, ops::Range}; use zerocopy::IntoBytes; -use zerocopy_derive::*; use super::{ - build::{self, BuildInto, BuildIntoMessage, TruncationError}, + build::{self, BuildBytes, BuildIntoMessage, TruncationError}, parse::{ ParseBytes, ParseError, ParseFromMessage, SplitBytes, SplitFromMessage, @@ -17,7 +16,6 @@ use super::{ //----------- CharStr -------------------------------------------------------- /// A DNS "character string". -#[derive(Immutable, Unaligned)] #[repr(transparent)] pub struct CharStr { /// The underlying octets. @@ -93,15 +91,15 @@ impl<'a> ParseBytes<'a> for &'a CharStr { //--- Building into byte strings -impl BuildInto for CharStr { - fn build_into<'b>( +impl BuildBytes for CharStr { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { let (length, bytes) = bytes.split_first_mut().ok_or(TruncationError)?; *length = self.octets.len() as u8; - self.octets.build_into(bytes) + self.octets.build_bytes(bytes) } } diff --git a/src/new_base/message.rs b/src/new_base/message.rs index 5d635a9e3..3307609bb 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -5,7 +5,7 @@ use core::fmt; use zerocopy::network_endian::U16; use zerocopy_derive::*; -use domain_macros::*; +use domain_macros::{AsBytes, *}; //----------- Message -------------------------------------------------------- @@ -35,6 +35,8 @@ pub struct Message { KnownLayout, Immutable, Unaligned, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -79,6 +81,8 @@ impl fmt::Display for Header { KnownLayout, Immutable, Unaligned, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -244,6 +248,8 @@ impl fmt::Display for HeaderFlags { KnownLayout, Immutable, Unaligned, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 7068e2e15..78ef94008 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -7,16 +7,16 @@ use core::{ iter::FusedIterator, }; -use zerocopy_derive::*; +use domain_macros::AsBytes; -use crate::new_base::parse::{ParseError, ParseBytes, SplitBytes}; +use crate::new_base::parse::{ParseBytes, ParseError, SplitBytes}; //----------- Label ---------------------------------------------------------- /// A label in a domain name. /// /// A label contains up to 63 bytes of arbitrary data. -#[derive(IntoBytes, Immutable, Unaligned)] +#[derive(AsBytes)] #[repr(transparent)] pub struct Label([u8]); diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index ee7b73b9e..6fae3c0f2 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -9,12 +9,12 @@ use core::{ }; use zerocopy::IntoBytes; -use zerocopy_derive::*; use crate::new_base::{ - build::{self, BuildInto, BuildIntoMessage, TruncationError}, + build::{self, BuildBytes, BuildIntoMessage, TruncationError}, parse::{ - ParseError, ParseBytes, ParseFromMessage, SplitBytes, SplitFromMessage, + ParseBytes, ParseError, ParseFromMessage, SplitBytes, + SplitFromMessage, }, Message, }; @@ -30,7 +30,6 @@ use super::LabelIter; /// use, making many common operations (e.g. comparing and ordering domain /// names) more computationally expensive. A [`RevName`] stores the labels in /// reversed order for more efficient use. -#[derive(Immutable, Unaligned)] #[repr(transparent)] pub struct RevName([u8]); @@ -113,8 +112,8 @@ impl BuildIntoMessage for RevName { //--- Building into byte strings -impl BuildInto for RevName { - fn build_into<'b>( +impl BuildBytes for RevName { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { @@ -213,7 +212,7 @@ impl fmt::Debug for RevName { //----------- RevNameBuf ----------------------------------------------------- /// A 256-byte buffer containing a [`RevName`]. -#[derive(Clone, Immutable, Unaligned)] +#[derive(Clone)] #[repr(C)] // make layout compatible with '[u8; 256]' pub struct RevNameBuf { /// The position of the root label in the buffer. @@ -422,12 +421,12 @@ impl<'a> ParseBytes<'a> for RevNameBuf { //--- Building into byte strings -impl BuildInto for RevNameBuf { - fn build_into<'b>( +impl BuildBytes for RevNameBuf { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { - (**self).build_into(bytes) + (**self).build_bytes(bytes) } } diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 8a9ad771f..4e93951aa 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -2,13 +2,12 @@ use core::ops::Range; -use zerocopy::{network_endian::U16, IntoBytes}; -use zerocopy_derive::*; +use zerocopy::network_endian::U16; use domain_macros::*; use super::{ - build::{self, BuildInto, BuildIntoMessage, TruncationError}, + build::{self, AsBytes, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ParseError, ParseFromMessage, SplitFromMessage}, Message, @@ -17,7 +16,7 @@ use super::{ //----------- Question ------------------------------------------------------- /// A DNS question. -#[derive(Clone, ParseBytes, SplitBytes)] +#[derive(Clone, BuildBytes, ParseBytes, SplitBytes)] pub struct Question { /// The domain name being requested. pub qname: N, @@ -95,23 +94,6 @@ where } } -//--- Building into byte strings - -impl BuildInto for Question -where - N: BuildInto, -{ - fn build_into<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - bytes = self.qname.build_into(bytes)?; - bytes = self.qtype.as_bytes().build_into(bytes)?; - bytes = self.qclass.as_bytes().build_into(bytes)?; - Ok(bytes) - } -} - //----------- QType ---------------------------------------------------------- /// The type of a question. @@ -124,8 +106,8 @@ where PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -149,8 +131,8 @@ pub struct QType { PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, diff --git a/src/new_base/record.rs b/src/new_base/record.rs index c02780d53..391b95dee 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -9,12 +9,11 @@ use zerocopy::{ network_endian::{U16, U32}, FromBytes, IntoBytes, }; -use zerocopy_derive::*; use domain_macros::*; use super::{ - build::{self, BuildInto, BuildIntoMessage, TruncationError}, + build::{self, AsBytes, BuildBytes, BuildIntoMessage, TruncationError}, name::RevNameBuf, parse::{ ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, @@ -191,25 +190,25 @@ where //--- Building into byte strings -impl BuildInto for Record +impl BuildBytes for Record where - N: BuildInto, - D: BuildInto, + N: BuildBytes, + D: BuildBytes, { - fn build_into<'b>( + fn build_bytes<'b>( &self, mut bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { - bytes = self.rname.build_into(bytes)?; - bytes = self.rtype.as_bytes().build_into(bytes)?; - bytes = self.rclass.as_bytes().build_into(bytes)?; - bytes = self.ttl.as_bytes().build_into(bytes)?; + bytes = self.rname.build_bytes(bytes)?; + bytes = self.rtype.as_bytes().build_bytes(bytes)?; + bytes = self.rclass.as_bytes().build_bytes(bytes)?; + bytes = self.ttl.as_bytes().build_bytes(bytes)?; let (size, bytes) = ::mut_from_prefix(bytes).map_err(|_| TruncationError)?; let bytes_len = bytes.len(); - let rest = self.rdata.build_into(bytes)?; + let rest = self.rdata.build_bytes(bytes)?; *size = u16::try_from(bytes_len - rest.len()) .expect("the record data never exceeds 64KiB") .into(); @@ -230,8 +229,8 @@ where PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -289,8 +288,8 @@ impl RType { PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -314,8 +313,8 @@ pub struct RClass { PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -351,7 +350,7 @@ pub trait ParseRecordData<'a>: Sized { //----------- UnparsedRecordData --------------------------------------------- /// Unparsed DNS record data. -#[derive(Immutable, Unaligned)] +#[derive(AsBytes, BuildBytes)] #[repr(transparent)] pub struct UnparsedRecordData([u8]); @@ -398,17 +397,6 @@ impl BuildIntoMessage for UnparsedRecordData { } } -//--- Building into byte strings - -impl BuildInto for UnparsedRecordData { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.0.build_into(bytes) - } -} - //--- Access to the underlying bytes impl Deref for UnparsedRecordData { diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs index 2fe5e8f7c..eaccf32f2 100644 --- a/src/new_base/serial.rs +++ b/src/new_base/serial.rs @@ -9,7 +9,6 @@ use core::{ }; use zerocopy::network_endian::U32; -use zerocopy_derive::*; use domain_macros::*; @@ -23,8 +22,8 @@ use domain_macros::*; PartialEq, Eq, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 53251ee7f..d9e8829ac 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -10,16 +10,12 @@ use core::str::FromStr; #[cfg(feature = "std")] use std::net::Ipv4Addr; -use zerocopy::{ - network_endian::{U16, U32}, - IntoBytes, -}; -use zerocopy_derive::*; +use zerocopy::network_endian::{U16, U32}; use domain_macros::*; use crate::new_base::{ - build::{self, BuildInto, BuildIntoMessage, TruncationError}, + build::{self, AsBytes, BuildIntoMessage, TruncationError}, parse::{ ParseBytes, ParseError, ParseFromMessage, SplitBytes, SplitFromMessage, @@ -39,8 +35,8 @@ use crate::new_base::{ PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -101,17 +97,6 @@ impl BuildIntoMessage for A { } } -//--- Building into byte strings - -impl BuildInto for A { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_into(bytes) - } -} - //----------- Ns ------------------------------------------------------------- /// The authoritative name server for this domain. @@ -124,6 +109,7 @@ impl BuildInto for A { PartialOrd, Ord, Hash, + BuildBytes, ParseBytes, SplitBytes, )] @@ -155,17 +141,6 @@ impl BuildIntoMessage for Ns { } } -//--- Building into bytes - -impl BuildInto for Ns { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.name.build_into(bytes) - } -} - //----------- Cname ---------------------------------------------------------- /// The canonical name for this domain. @@ -178,6 +153,7 @@ impl BuildInto for Ns { PartialOrd, Ord, Hash, + BuildBytes, ParseBytes, SplitBytes, )] @@ -209,21 +185,20 @@ impl BuildIntoMessage for CName { } } -//--- Building into bytes - -impl BuildInto for CName { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.name.build_into(bytes) - } -} - //----------- Soa ------------------------------------------------------------ /// The start of a zone of authority. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, ParseBytes, SplitBytes)] +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + Hash, + BuildBytes, + ParseBytes, + SplitBytes, +)] pub struct Soa { /// The name server which provided this zone. pub mname: N, @@ -293,28 +268,10 @@ impl BuildIntoMessage for Soa { } } -//--- Building into byte strings - -impl BuildInto for Soa { - fn build_into<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - bytes = self.mname.build_into(bytes)?; - bytes = self.rname.build_into(bytes)?; - bytes = self.serial.as_bytes().build_into(bytes)?; - bytes = self.refresh.as_bytes().build_into(bytes)?; - bytes = self.retry.as_bytes().build_into(bytes)?; - bytes = self.expire.as_bytes().build_into(bytes)?; - bytes = self.minimum.as_bytes().build_into(bytes)?; - Ok(bytes) - } -} - //----------- Wks ------------------------------------------------------------ /// Well-known services supported on this domain. -#[derive(IntoBytes, Immutable, ParseBytesByRef)] +#[derive(AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C, packed)] pub struct Wks { /// The address of the host providing these services. @@ -366,17 +323,6 @@ impl BuildIntoMessage for Wks { } } -//--- Building into byte strings - -impl BuildInto for Wks { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_into(bytes) - } -} - //----------- Ptr ------------------------------------------------------------ /// A pointer to another domain name. @@ -389,6 +335,7 @@ impl BuildInto for Wks { PartialOrd, Ord, Hash, + BuildBytes, ParseBytes, SplitBytes, )] @@ -420,21 +367,10 @@ impl BuildIntoMessage for Ptr { } } -//--- Building into bytes - -impl BuildInto for Ptr { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.name.build_into(bytes) - } -} - //----------- HInfo ---------------------------------------------------------- /// Information about the host computer. -#[derive(Clone, Debug, PartialEq, Eq, ParseBytes, SplitBytes)] +#[derive(Clone, Debug, PartialEq, Eq, BuildBytes, ParseBytes, SplitBytes)] pub struct HInfo<'a> { /// The CPU type. pub cpu: &'a CharStr, @@ -450,6 +386,8 @@ impl<'a> ParseFromMessage<'a> for HInfo<'a> { message: &'a Message, range: Range, ) -> Result { + use zerocopy::IntoBytes; + message .as_bytes() .get(range) @@ -472,19 +410,6 @@ impl BuildIntoMessage for HInfo<'_> { } } -//--- Building into bytes - -impl BuildInto for HInfo<'_> { - fn build_into<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - bytes = self.cpu.build_into(bytes)?; - bytes = self.os.build_into(bytes)?; - Ok(bytes) - } -} - //----------- Mx ------------------------------------------------------------- /// A host that can exchange mail for this domain. @@ -497,6 +422,7 @@ impl BuildInto for HInfo<'_> { PartialOrd, Ord, Hash, + BuildBytes, ParseBytes, SplitBytes, )] @@ -540,23 +466,10 @@ impl BuildIntoMessage for Mx { } } -//--- Building into byte strings - -impl BuildInto for Mx { - fn build_into<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - bytes = self.preference.as_bytes().build_into(bytes)?; - bytes = self.exchange.build_into(bytes)?; - Ok(bytes) - } -} - //----------- Txt ------------------------------------------------------------ /// Free-form text strings about this domain. -#[derive(IntoBytes, Immutable, Unaligned)] +#[derive(AsBytes, BuildBytes)] #[repr(transparent)] pub struct Txt { /// The text strings, as concatenated [`CharStr`]s. @@ -588,6 +501,8 @@ impl<'a> ParseFromMessage<'a> for &'a Txt { message: &'a Message, range: Range, ) -> Result { + use zerocopy::IntoBytes; + message .as_bytes() .get(range) @@ -622,17 +537,6 @@ impl<'a> ParseBytes<'a> for &'a Txt { } } -//--- Building into byte strings - -impl BuildInto for Txt { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.content.build_into(bytes) - } -} - //--- Formatting impl fmt::Debug for Txt { diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 4f84ba837..c53a715a7 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -2,26 +2,15 @@ //! //! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). -use zerocopy_derive::*; - use domain_macros::*; -use crate::new_base::build::{ - self, BuildInto, BuildIntoMessage, TruncationError, -}; +use crate::new_base::build::{self, BuildIntoMessage, TruncationError}; //----------- Opt ------------------------------------------------------------ /// Extended DNS options. #[derive( - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - IntoBytes, - Immutable, - ParseBytesByRef, + PartialEq, Eq, PartialOrd, Ord, Hash, AsBytes, BuildBytes, ParseBytesByRef, )] #[repr(transparent)] pub struct Opt { @@ -42,14 +31,3 @@ impl BuildIntoMessage for Opt { self.contents.build_into_message(builder) } } - -//--- Building into byte strings - -impl BuildInto for Opt { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.contents.build_into(bytes) - } -} diff --git a/src/new_rdata/ipv6.rs b/src/new_rdata/ipv6.rs index 77df07cc5..fb3f9d30e 100644 --- a/src/new_rdata/ipv6.rs +++ b/src/new_rdata/ipv6.rs @@ -8,13 +8,10 @@ use core::{fmt, str::FromStr}; #[cfg(feature = "std")] use std::net::Ipv6Addr; -use zerocopy::IntoBytes; -use zerocopy_derive::*; - use domain_macros::*; use crate::new_base::build::{ - self, BuildInto, BuildIntoMessage, TruncationError, + self, AsBytes, BuildIntoMessage, TruncationError, }; //----------- Aaaa ----------------------------------------------------------- @@ -29,8 +26,8 @@ use crate::new_base::build::{ PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -90,14 +87,3 @@ impl BuildIntoMessage for Aaaa { self.as_bytes().build_into_message(builder) } } - -//--- Building into byte strings - -impl BuildInto for Aaaa { - fn build_into<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_into(bytes) - } -} diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 0cf020988..1be038e45 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -2,11 +2,10 @@ use core::ops::Range; -use domain_macros::ParseBytesByRef; -use zerocopy_derive::*; +use domain_macros::*; use crate::new_base::{ - build::{BuildInto, BuildIntoMessage, Builder, TruncationError}, + build::{BuildBytes, BuildIntoMessage, Builder, TruncationError}, parse::{ ParseBytes, ParseError, ParseFromMessage, SplitBytes, SplitFromMessage, @@ -150,23 +149,23 @@ impl BuildIntoMessage for RecordData<'_, N> { } } -impl BuildInto for RecordData<'_, N> { - fn build_into<'b>( +impl BuildBytes for RecordData<'_, N> { + fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { match self { - Self::A(r) => r.build_into(bytes), - Self::Ns(r) => r.build_into(bytes), - Self::CName(r) => r.build_into(bytes), - Self::Soa(r) => r.build_into(bytes), - Self::Wks(r) => r.build_into(bytes), - Self::Ptr(r) => r.build_into(bytes), - Self::HInfo(r) => r.build_into(bytes), - Self::Txt(r) => r.build_into(bytes), - Self::Aaaa(r) => r.build_into(bytes), - Self::Mx(r) => r.build_into(bytes), - Self::Unknown(_, r) => r.octets.build_into(bytes), + Self::A(r) => r.build_bytes(bytes), + Self::Ns(r) => r.build_bytes(bytes), + Self::CName(r) => r.build_bytes(bytes), + Self::Soa(r) => r.build_bytes(bytes), + Self::Wks(r) => r.build_bytes(bytes), + Self::Ptr(r) => r.build_bytes(bytes), + Self::HInfo(r) => r.build_bytes(bytes), + Self::Txt(r) => r.build_bytes(bytes), + Self::Aaaa(r) => r.build_bytes(bytes), + Self::Mx(r) => r.build_bytes(bytes), + Self::Unknown(_, r) => r.build_bytes(bytes), } } } @@ -174,7 +173,7 @@ impl BuildInto for RecordData<'_, N> { //----------- UnknownRecordData ---------------------------------------------- /// Data for an unknown DNS record type. -#[derive(Debug, IntoBytes, Immutable, ParseBytesByRef)] +#[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C)] pub struct UnknownRecordData { /// The unparsed option data. From d1f94ba1bfa6ce21f933642524f81a530d3b3124 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 17:58:22 +0100 Subject: [PATCH 063/167] [macros] Minor fixes as per clippy --- macros/src/lib.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 2885844af..605cac3be 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -43,7 +43,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { skeleton.bound = Some( parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), ); - if skeleton.lifetimes.len() > 0 { + if !skeleton.lifetimes.is_empty() { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { #lifetime: #(#lifetimes)+* @@ -105,7 +105,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { } }); - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); @@ -175,7 +175,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { skeleton.bound = Some( parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), ); - if skeleton.lifetimes.len() > 0 { + if !skeleton.lifetimes.is_empty() { let lifetimes = skeleton.lifetimes.iter(); let param = parse_quote! { #lifetime: #(#lifetimes)+* @@ -252,7 +252,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { } }); - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); @@ -360,7 +360,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }); - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); @@ -501,7 +501,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }); - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); @@ -564,7 +564,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { } }); - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); @@ -611,7 +611,7 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { // The default implementation of 'as_bytes()' works perfectly. - Ok(skeleton.into_token_stream().into()) + Ok(skeleton.into_token_stream()) } let input = syn::parse_macro_input!(input as DeriveInput); From c0368ea875d9de4abcdf637e7f40a59022ad8bb2 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 18:52:20 +0100 Subject: [PATCH 064/167] [new_base/parse] Fix missing doc link --- src/new_base/parse/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 493542b66..c4bba79e4 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -154,6 +154,8 @@ impl<'a> ParseBytes<'a> for U32 { /// documented on [`split_bytes_by_ref()`]. An incorrect implementation is /// considered to cause undefined behaviour. /// +/// [`split_bytes_by_ref()`]: Self::split_bytes_by_ref() +/// /// Note that [`ParseBytesByRef`], required by this trait, also has several /// invariants that need to be considered with care. pub unsafe trait SplitBytesByRef: ParseBytesByRef { From 42d48e007b0fe64218bb454da50131762dc8fb61 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 18:52:34 +0100 Subject: [PATCH 065/167] [macros] Fix no-fields output for 'ParseBytes' --- macros/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 605cac3be..67285d420 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -160,7 +160,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { if bytes.is_empty() { Ok(Self #fields) } else { - Err() + Err(::domain::new_base::parse::ParseError) } } } From 90b531a2a540e6945a279e37558944c52f5e7f55 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 18:53:00 +0100 Subject: [PATCH 066/167] [new_base/serial] Support measuring unix time --- src/new_base/serial.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs index eaccf32f2..4258c4b22 100644 --- a/src/new_base/serial.rs +++ b/src/new_base/serial.rs @@ -32,6 +32,21 @@ use domain_macros::*; #[repr(transparent)] pub struct Serial(U32); +//--- Construction + +impl Serial { + /// Measure the current time (in seconds) in serial number space. + #[cfg(feature = "std")] + pub fn unix_time() -> Self { + use std::time::SystemTime; + + let time = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("The current time is after the Unix Epoch"); + Self::from(time.as_secs() as u32) + } +} + //--- Addition impl Add for Serial { From f8fc52583788b4ba1ad053bdd0e126079aaffa4b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 2 Jan 2025 18:53:13 +0100 Subject: [PATCH 067/167] [new_edns] Implement DNS cookie support --- src/new_edns/cookie.rs | 242 +++++++++++++++++++++++++++++++++++++++++ src/new_edns/mod.rs | 16 ++- 2 files changed, 252 insertions(+), 6 deletions(-) create mode 100644 src/new_edns/cookie.rs diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs new file mode 100644 index 000000000..466e8c606 --- /dev/null +++ b/src/new_edns/cookie.rs @@ -0,0 +1,242 @@ +//! DNS cookies. +//! +//! See [RFC 7873] and [RFC 9018]. +//! +//! [RFC 7873]: https://datatracker.ietf.org/doc/html/rfc7873 +//! [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 + +use core::fmt; + +#[cfg(all(feature = "std", feature = "siphasher"))] +use core::ops::Range; + +#[cfg(all(feature = "std", feature = "siphasher"))] +use std::net::IpAddr; + +use domain_macros::*; + +use crate::new_base::Serial; + +#[cfg(all(feature = "std", feature = "siphasher"))] +use crate::new_base::build::{AsBytes, TruncationError}; + +//----------- CookieRequest -------------------------------------------------- + +/// A request for a DNS cookie. +#[derive( + Copy, + Clone, + PartialEq, + Eq, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(transparent)] +pub struct CookieRequest { + /// The octets of the request. + pub octets: [u8; 8], +} + +//--- Construction + +impl CookieRequest { + /// Construct a random [`CookieRequest`]. + #[cfg(feature = "rand")] + pub fn random() -> Self { + rand::random::<[u8; 8]>().into() + } +} + +//--- Interaction + +impl CookieRequest { + /// Build a [`Cookie`] in response to this request. + /// + /// A 24-byte version-1 interoperable cookie will be generated and written + /// to the given buffer. If the buffer is big enough, the remaining part + /// of the buffer is returned. + #[cfg(all(feature = "std", feature = "siphasher"))] + pub fn respond_into<'b>( + &self, + addr: IpAddr, + secret: &[u8; 16], + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + use core::hash::Hasher; + + use siphasher::sip::SipHasher24; + + use crate::new_base::build::BuildBytes; + + // Build and hash the cookie simultaneously. + let mut hasher = SipHasher24::new_with_key(secret); + + bytes = self.build_bytes(bytes)?; + hasher.write(self.as_bytes()); + + // The version number and the reserved octets. + bytes = [1, 0, 0, 0].build_bytes(bytes)?; + hasher.write(&[1, 0, 0, 0]); + + let timestamp = Serial::unix_time(); + bytes = timestamp.build_bytes(bytes)?; + hasher.write(timestamp.as_bytes()); + + match addr { + IpAddr::V4(addr) => hasher.write(&addr.octets()), + IpAddr::V6(addr) => hasher.write(&addr.octets()), + } + + let hash = hasher.finish().to_le_bytes(); + bytes = hash.build_bytes(bytes)?; + + Ok(bytes) + } +} + +//--- Conversion to and from octets + +impl From<[u8; 8]> for CookieRequest { + fn from(value: [u8; 8]) -> Self { + Self { octets: value } + } +} + +impl From for [u8; 8] { + fn from(value: CookieRequest) -> Self { + value.octets + } +} + +//--- Formatting + +impl fmt::Debug for CookieRequest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "CookieRequest({})", self) + } +} + +impl fmt::Display for CookieRequest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:016X}", u64::from_be_bytes(self.octets)) + } +} + +//----------- Cookie --------------------------------------------------------- + +/// A DNS cookie. +#[derive(PartialEq, Eq, Hash, AsBytes, BuildBytes, ParseBytesByRef)] +#[repr(C)] +pub struct Cookie { + /// The request for this cookie. + request: CookieRequest, + + /// The version number of this cookie. + version: u8, + + /// Reserved bytes in the cookie format. + reversed: [u8; 3], + + /// When this cookie was made. + timestamp: Serial, + + /// The hash of this cookie. + hash: [u8], +} + +//--- Inspection + +impl Cookie { + /// The underlying cookie request. + pub fn request(&self) -> &CookieRequest { + &self.request + } + + /// The version number of this interoperable cookie. + /// + /// Assuming this is an interoperable cookie, as specified by [RFC 9018], + /// the 1-byte version number of the cookie is returned. Currently, only + /// version 1 has been specified. + /// + /// [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 + pub fn version(&self) -> u8 { + self.version + } + + /// When this interoperable cookie was produced. + /// + /// Assuming this is an interoperable cookie, as specified by [RFC 9018], + /// the 4-byte timestamp of the cookie is returned. + /// + /// [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 + pub fn timestamp(&self) -> Serial { + self.timestamp + } +} + +//--- Interaction + +impl Cookie { + /// Verify this cookie. + /// + /// This cookie is verified as a 24-byte version-1 interoperable cookie, + /// as specified by [RFC 9018]. A 16-byte secret is used to generate a + /// hash for this cookie, based on its fields and the IP address of the + /// client which used it. If the cookie was generated in the given time + /// period, and the generated hash matches the hash in the cookie, it is + /// valid. + /// + /// [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 + #[cfg(all(feature = "std", feature = "siphasher"))] + pub fn verify( + &self, + addr: IpAddr, + secret: &[u8; 16], + validity: Range, + ) -> Result<(), CookieError> { + use core::hash::Hasher; + + use siphasher::sip::SipHasher24; + + // Check basic features of the cookie. + if self.version != 1 + || self.hash.len() != 8 + || !validity.contains(&self.timestamp) + { + return Err(CookieError); + } + + // Check the cookie hash. + let mut hasher = SipHasher24::new_with_key(secret); + hasher.write(&self.as_bytes()[..16]); + match addr { + IpAddr::V4(addr) => hasher.write(&addr.octets()), + IpAddr::V6(addr) => hasher.write(&addr.octets()), + } + + if self.hash == hasher.finish().to_le_bytes() { + Ok(()) + } else { + Err(CookieError) + } + } +} + +//----------- CookieError ---------------------------------------------------- + +/// An invalid [`Cookie`] was encountered. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct CookieError; + +//--- Formatting + +impl fmt::Display for CookieError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("A DNS cookie could not be verified") + } +} diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index d5a1d366f..e72529fb3 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -5,7 +5,6 @@ use core::{fmt, ops::Range}; use zerocopy::{network_endian::U16, IntoBytes}; -use zerocopy_derive::*; use domain_macros::*; @@ -20,6 +19,11 @@ use crate::{ new_rdata::Opt, }; +//----------- EDNS option modules -------------------------------------------- + +mod cookie; +pub use cookie::{Cookie, CookieRequest}; + //----------- EdnsRecord ----------------------------------------------------- /// An Extended DNS record. @@ -137,8 +141,8 @@ impl<'a> ParseBytes<'a> for EdnsRecord<'a> { Clone, Default, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -217,8 +221,8 @@ pub enum EdnsOption<'b> { PartialOrd, Ord, Hash, - IntoBytes, - Immutable, + AsBytes, + BuildBytes, ParseBytes, ParseBytesByRef, SplitBytes, @@ -233,7 +237,7 @@ pub struct OptionCode { //----------- UnknownOption -------------------------------------------------- /// Data for an unknown Extended DNS option. -#[derive(Debug, IntoBytes, Immutable, ParseBytesByRef)] +#[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C)] pub struct UnknownOption { /// The unparsed option data. From d4455874ec2ee7ea416549643a6b28e53e56f96d Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 3 Jan 2025 12:13:40 +0100 Subject: [PATCH 068/167] [macros] Factor out struct inspection and building --- macros/src/data.rs | 159 +++++++++++++++++++ macros/src/impls.rs | 16 ++ macros/src/lib.rs | 375 ++++++++++++++++++-------------------------- 3 files changed, 326 insertions(+), 224 deletions(-) create mode 100644 macros/src/data.rs diff --git a/macros/src/data.rs b/macros/src/data.rs new file mode 100644 index 000000000..6a0788b3e --- /dev/null +++ b/macros/src/data.rs @@ -0,0 +1,159 @@ +//! Working with structs, enums, and unions. + +use std::ops::Deref; + +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{spanned::Spanned, *}; + +//----------- Struct --------------------------------------------------------- + +/// A defined 'struct'. +pub struct Struct { + /// The identifier for this 'struct'. + ident: Ident, + + /// The fields in this 'struct'. + fields: Fields, +} + +impl Struct { + /// Construct a [`Struct`] for a 'Self'. + pub fn new_as_self(fields: &Fields) -> Self { + Self { + ident: ::default().into(), + fields: fields.clone(), + } + } + + /// Whether this 'struct' has no fields. + pub fn is_empty(&self) -> bool { + self.fields.is_empty() + } + + /// The number of fields in this 'struct'. + pub fn num_fields(&self) -> usize { + self.fields.len() + } + + /// The fields of this 'struct'. + pub fn fields(&self) -> impl Iterator + '_ { + self.fields.iter() + } + + /// The sized fields of this 'struct'. + pub fn sized_fields(&self) -> impl Iterator + '_ { + self.fields().take(self.num_fields() - 1) + } + + /// The unsized field of this 'struct'. + pub fn unsized_field(&self) -> Option<&Field> { + self.fields.iter().next_back() + } + + /// The names of the fields of this 'struct'. + pub fn members(&self) -> impl Iterator + '_ { + self.fields + .iter() + .enumerate() + .map(|(i, f)| make_member(i, f)) + } + + /// The names of the sized fields of this 'struct'. + pub fn sized_members(&self) -> impl Iterator + '_ { + self.members().take(self.num_fields() - 1) + } + + /// The name of the last field of this 'struct'. + pub fn unsized_member(&self) -> Option { + self.fields + .iter() + .next_back() + .map(|f| make_member(self.num_fields() - 1, f)) + } + + /// Construct a builder for this 'struct'. + pub fn builder Ident>( + &self, + f: F, + ) -> StructBuilder<'_, F> { + StructBuilder { + target: self, + var_fn: f, + } + } +} + +/// Construct a [`Member`] from a field and index. +fn make_member(index: usize, field: &Field) -> Member { + match &field.ident { + Some(ident) => Member::Named(ident.clone()), + None => Member::Unnamed(Index { + index: index as u32, + span: field.ty.span(), + }), + } +} + +//----------- StructBuilder -------------------------------------------------- + +/// A means of constructing a 'struct'. +pub struct StructBuilder<'a, F: Fn(Member) -> Ident> { + /// The 'struct' being constructed. + target: &'a Struct, + + /// A map from field names to constructing variables. + var_fn: F, +} + +impl Ident> StructBuilder<'_, F> { + /// The initializing variables for this 'struct'. + pub fn init_vars(&self) -> impl Iterator + '_ { + self.members().map(&self.var_fn) + } + + /// The names of the sized fields of this 'struct'. + pub fn sized_init_vars(&self) -> impl Iterator + '_ { + self.sized_members().map(&self.var_fn) + } + + /// The name of the last field of this 'struct'. + pub fn unsized_init_var(&self) -> Option { + self.unsized_member().map(&self.var_fn) + } +} + +impl Ident> Deref for StructBuilder<'_, F> { + type Target = Struct; + + fn deref(&self) -> &Self::Target { + self.target + } +} + +impl Ident> ToTokens for StructBuilder<'_, F> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = &self.ident; + match self.fields { + Fields::Named(_) => { + let members = self.members(); + let init_vars = self.init_vars(); + quote! { + #ident { #(#members: #init_vars),* } + } + } + + Fields::Unnamed(_) => { + let init_vars = self.init_vars(); + quote! { + #ident ( #(#init_vars),* ) + } + } + + Fields::Unit => { + quote! { #ident } + } + } + .to_tokens(tokens); + } +} diff --git a/macros/src/impls.rs b/macros/src/impls.rs index 5e3b884a0..2d9724f0e 100644 --- a/macros/src/impls.rs +++ b/macros/src/impls.rs @@ -186,6 +186,22 @@ impl ImplSkeleton { }) .unwrap() } + + /// Generate a unique lifetime parameter with the given prefix and bounds. + pub fn new_lifetime_param( + &self, + prefix: &str, + bounds: impl IntoIterator, + ) -> (Lifetime, LifetimeParam) { + let lifetime = self.new_lifetime(prefix); + let mut bounds = bounds.into_iter().peekable(); + let param = if bounds.peek().is_some() { + parse_quote! { #lifetime: #(#bounds)+* } + } else { + parse_quote! { #lifetime } + }; + (lifetime, param) + } } impl ToTokens for ImplSkeleton { diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 67285d420..99d209fff 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -4,13 +4,15 @@ use proc_macro as pm; use proc_macro2::TokenStream; -use quote::{format_ident, quote, ToTokens}; -use spanned::Spanned; +use quote::{format_ident, ToTokens}; use syn::*; mod impls; use impls::ImplSkeleton; +mod data; +use data::Struct; + mod repr; use repr::Repr; @@ -39,58 +41,30 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { let mut skeleton = ImplSkeleton::new(&input, false); // Add the parsing lifetime to the 'impl'. - let lifetime = skeleton.new_lifetime("bytes"); + let (lifetime, param) = skeleton.new_lifetime_param( + "bytes", + skeleton.lifetimes.iter().map(|l| l.lifetime.clone()), + ); + skeleton.lifetimes.push(param); skeleton.bound = Some( parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), ); - if !skeleton.lifetimes.is_empty() { - let lifetimes = skeleton.lifetimes.iter(); - let param = parse_quote! { - #lifetime: #(#lifetimes)+* - }; - skeleton.lifetimes.push(param); - } else { - skeleton.lifetimes.push(parse_quote! { #lifetime }) - } + + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + let builder = data.builder(field_prefixed); // Establish bounds on the fields. - for field in data.fields.iter() { + for field in data.fields() { skeleton.require_bound( field.ty.clone(), parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), ); } - // Construct a 'Self' expression. - let self_expr = match &data.fields { - Fields::Named(_) => { - let names = data.fields.members(); - let exprs = - names.clone().map(|n| format_ident!("field_{}", n)); - quote! { - Self { - #(#names: #exprs,)* - } - } - } - - Fields::Unnamed(_) => { - let exprs = data - .fields - .members() - .map(|n| format_ident!("field_{}", n)); - quote! { - Self(#(#exprs,)*) - } - } - - Fields::Unit => quote! { Self }, - }; - // Define 'parse_bytes()'. - let names = - data.fields.members().map(|n| format_ident!("field_{}", n)); - let tys = data.fields.iter().map(|f| &f.ty); + let init_vars = builder.init_vars(); + let tys = data.fields().map(|f| &f.ty); skeleton.contents.stmts.push(parse_quote! { fn split_bytes( bytes: & #lifetime [::domain::__core::primitive::u8], @@ -98,10 +72,10 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { (Self, & #lifetime [::domain::__core::primitive::u8]), ::domain::new_base::parse::ParseError, > { - #(let (#names, bytes) = + #(let (#init_vars, bytes) = <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> ::split_bytes(bytes)?;)* - Ok((#self_expr, bytes)) + Ok((#builder, bytes)) } }); @@ -135,106 +109,62 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { } }; - // Split up the last field from the rest. - let mut fields = data.fields.iter(); - let Some(last) = fields.next_back() else { - // This type has no fields. Return a simple implementation. - assert!(input.generics.params.is_empty()); - let where_clause = input.generics.where_clause; - let name = input.ident; - - // This will tokenize to '{}', '()', or ''. - let fields = data.fields.to_token_stream(); - - return Ok(quote! { - impl <'bytes> - ::domain::new_base::parse::ParseBytes<'bytes> - for #name - #where_clause { - fn parse_bytes( - bytes: &'bytes [::domain::__core::primitive::u8], - ) -> ::domain::__core::result::Result< - Self, - ::domain::new_base::parse::ParseError, - > { - if bytes.is_empty() { - Ok(Self #fields) - } else { - Err(::domain::new_base::parse::ParseError) - } - } - } - }); - }; - // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, false); // Add the parsing lifetime to the 'impl'. - let lifetime = skeleton.new_lifetime("bytes"); + let (lifetime, param) = skeleton.new_lifetime_param( + "bytes", + skeleton.lifetimes.iter().map(|l| l.lifetime.clone()), + ); + skeleton.lifetimes.push(param); skeleton.bound = Some( parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), ); - if !skeleton.lifetimes.is_empty() { - let lifetimes = skeleton.lifetimes.iter(); - let param = parse_quote! { - #lifetime: #(#lifetimes)+* - }; - skeleton.lifetimes.push(param); - } else { - skeleton.lifetimes.push(parse_quote! { #lifetime }) - } + + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + let builder = data.builder(field_prefixed); // Establish bounds on the fields. - for field in fields.clone() { - // This field should implement 'SplitBytes'. + for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), ); } - // The last field should implement 'ParseBytes'. - skeleton.require_bound( - last.ty.clone(), - parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), - ); + if let Some(field) = data.unsized_field() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), + ); + } - // Construct a 'Self' expression. - let self_expr = match &data.fields { - Fields::Named(_) => { - let names = data.fields.members(); - let exprs = - names.clone().map(|n| format_ident!("field_{}", n)); - quote! { - Self { - #(#names: #exprs,)* + // Finish early if the 'struct' has no fields. + if data.is_empty() { + skeleton.contents.stmts.push(parse_quote! { + fn parse_bytes( + bytes: & #lifetime [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + Self, + ::domain::new_base::parse::ParseError, + > { + if bytes.is_empty() { + Ok(#builder) + } else { + Err(::domain::new_base::parse::ParseError) } } - } - - Fields::Unnamed(_) => { - let exprs = data - .fields - .members() - .map(|n| format_ident!("field_{}", n)); - quote! { - Self(#(#exprs,)*) - } - } + }); - Fields::Unit => unreachable!(), - }; + return Ok(skeleton.into_token_stream()); + } // Define 'parse_bytes()'. - let names = data - .fields - .members() - .take(fields.len()) - .map(|n| format_ident!("field_{}", n)); - let tys = fields.clone().map(|f| &f.ty); - let last_ty = &last.ty; - let last_name = - format_ident!("field_{}", data.fields.members().last().unwrap()); + let init_vars = builder.sized_init_vars(); + let tys = builder.sized_fields().map(|f| &f.ty); + let unsized_ty = &builder.unsized_field().unwrap().ty; + let unsized_init_var = builder.unsized_init_var().unwrap(); skeleton.contents.stmts.push(parse_quote! { fn parse_bytes( bytes: & #lifetime [::domain::__core::primitive::u8], @@ -242,13 +172,13 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { Self, ::domain::new_base::parse::ParseError, > { - #(let (#names, bytes) = + #(let (#init_vars, bytes) = <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> ::split_bytes(bytes)?;)* - let #last_name = - <#last_ty as ::domain::new_base::parse::ParseBytes<#lifetime>> + let #unsized_init_var = + <#unsized_ty as ::domain::new_base::parse::ParseBytes<#lifetime>> ::parse_bytes(bytes)?; - Ok(#self_expr) + Ok(#builder) } }); @@ -284,50 +214,47 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { let _ = Repr::determine(&input.attrs, "SplitBytesByRef")?; - // Split up the last field from the rest. - let mut fields = data.fields.iter(); - let Some(last) = fields.next_back() else { - // This type has no fields. Return a simple implementation. - let (impl_generics, ty_generics, where_clause) = - input.generics.split_for_impl(); - let name = input.ident; - - return Ok(quote! { - unsafe impl #impl_generics - ::domain::new_base::parse::SplitBytesByRef - for #name #ty_generics - #where_clause { - fn split_bytes_by_ref( - bytes: &[::domain::__core::primitive::u8], - ) -> ::domain::__core::result::Result< - (&Self, &[::domain::__core::primitive::u8]), - ::domain::new_base::parse::ParseError, - > { - Ok(( - unsafe { &*bytes.as_ptr().cast::() }, - bytes, - )) - } - } - }); - }; - // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = Some(parse_quote!(::domain::new_base::parse::SplitBytesByRef)); + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + // Establish bounds on the fields. - for field in data.fields.iter() { + for field in data.fields() { skeleton.require_bound( field.ty.clone(), parse_quote!(::domain::new_base::parse::SplitBytesByRef), ); } + // Finish early if the 'struct' has no fields. + if data.is_empty() { + skeleton.contents.stmts.push(parse_quote! { + fn split_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (&Self, &[::domain::__core::primitive::u8]), + ::domain::new_base::parse::ParseError, + > { + Ok(( + // SAFETY: 'Self' is a 'struct' with no fields, + // and so has size 0 and alignment 1. It can be + // constructed at any address. + unsafe { &*bytes.as_ptr().cast::() }, + bytes, + )) + } + }); + + return Ok(skeleton.into_token_stream()); + } + // Define 'split_bytes_by_ref()'. - let tys = fields.clone().map(|f| &f.ty); - let last_ty = &last.ty; + let tys = data.sized_fields().map(|f| &f.ty); + let unsized_ty = &data.unsized_field().unwrap().ty; skeleton.contents.stmts.push(parse_quote! { fn split_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], @@ -340,10 +267,10 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { <#tys as ::domain::new_base::parse::SplitBytesByRef> ::split_bytes_by_ref(bytes)?;)* let (last, rest) = - <#last_ty as ::domain::new_base::parse::SplitBytesByRef> + <#unsized_ty as ::domain::new_base::parse::SplitBytesByRef> ::split_bytes_by_ref(bytes)?; let ptr = - <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> ::ptr_with_address(last, start as *const ()); // SAFETY: @@ -392,67 +319,63 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { let _ = Repr::determine(&input.attrs, "ParseBytesByRef")?; - // Split up the last field from the rest. - let mut fields = data.fields.iter(); - let Some(last) = fields.next_back() else { - // This type has no fields. Return a simple implementation. - let (impl_generics, ty_generics, where_clause) = - input.generics.split_for_impl(); - let name = input.ident; - - return Ok(quote! { - unsafe impl #impl_generics - ::domain::new_base::parse::ParseBytesByRef - for #name #ty_generics - #where_clause { - fn parse_bytes_by_ref( - bytes: &[::domain::__core::primitive::u8], - ) -> ::domain::__core::result::Result< - &Self, - ::domain::new_base::parse::ParseError, - > { - if bytes.is_empty() { - // SAFETY: 'Self' is a 'struct' with no fields, - // and so has size 0 and alignment 1. It can be - // constructed at any address. - Ok(unsafe { &*bytes.as_ptr().cast::() }) - } else { - Err(::domain::new_base::parse::ParseError) - } - } - - fn ptr_with_address( - &self, - addr: *const (), - ) -> *const Self { - addr.cast() - } - } - }); - }; - // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = Some(parse_quote!(::domain::new_base::parse::ParseBytesByRef)); + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + // Establish bounds on the fields. - for field in fields.clone() { - // This field should implement 'SplitBytesByRef'. + for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), parse_quote!(::domain::new_base::parse::SplitBytesByRef), ); } - // The last field should implement 'ParseBytesByRef'. - skeleton.require_bound( - last.ty.clone(), - parse_quote!(::domain::new_base::parse::ParseBytesByRef), - ); + if let Some(field) = data.unsized_field() { + skeleton.require_bound( + field.ty.clone(), + parse_quote!(::domain::new_base::parse::ParseBytesByRef), + ); + } + + // Finish early if the 'struct' has no fields. + if data.is_empty() { + skeleton.contents.stmts.push(parse_quote! { + fn parse_bytes_by_ref( + bytes: &[::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + &Self, + ::domain::new_base::parse::ParseError, + > { + if bytes.is_empty() { + // SAFETY: 'Self' is a 'struct' with no fields, + // and so has size 0 and alignment 1. It can be + // constructed at any address. + Ok(unsafe { &*bytes.as_ptr().cast::() }) + } else { + Err(::domain::new_base::parse::ParseError) + } + } + }); + + skeleton.contents.stmts.push(parse_quote! { + fn ptr_with_address( + &self, + addr: *const (), + ) -> *const Self { + addr.cast() + } + }); + + return Ok(skeleton.into_token_stream()); + } // Define 'parse_bytes_by_ref()'. - let tys = fields.clone().map(|f| &f.ty); - let last_ty = &last.ty; + let tys = data.sized_fields().map(|f| &f.ty); + let unsized_ty = &data.unsized_field().unwrap().ty; skeleton.contents.stmts.push(parse_quote! { fn parse_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], @@ -465,10 +388,10 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { <#tys as ::domain::new_base::parse::SplitBytesByRef> ::split_bytes_by_ref(bytes)?;)* let last = - <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> ::parse_bytes_by_ref(bytes)?; let ptr = - <#last_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> ::ptr_with_address(last, start as *const ()); // SAFETY: @@ -486,17 +409,11 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { }); // Define 'ptr_with_address()'. - let last_name = match last.ident.as_ref() { - Some(ident) => Member::Named(ident.clone()), - None => Member::Unnamed(Index { - index: data.fields.len() as u32 - 1, - span: last.ty.span(), - }), - }; + let unsized_member = data.unsized_member(); skeleton.contents.stmts.push(parse_quote! { fn ptr_with_address(&self, addr: *const ()) -> *const Self { - <#last_ty as ::domain::new_base::parse::ParseBytesByRef> - ::ptr_with_address(&self.#last_name, addr) + <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> + ::ptr_with_address(&self.#unsized_member, addr) as *const Self } }); @@ -536,11 +453,14 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { skeleton.bound = Some(parse_quote!(::domain::new_base::build::BuildBytes)); + // Inspect the 'struct' fields. + let data = Struct::new_as_self(&data.fields); + // Get a lifetime for the input buffer. let lifetime = skeleton.new_lifetime("bytes"); // Establish bounds on the fields. - for field in data.fields.iter() { + for field in data.fields() { skeleton.require_bound( field.ty.clone(), parse_quote!(::domain::new_base::build::BuildBytes), @@ -548,8 +468,8 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { } // Define 'build_bytes()'. - let names = data.fields.members(); - let tys = data.fields.iter().map(|f| &f.ty); + let members = data.members(); + let tys = data.fields().map(|f| &f.ty); skeleton.contents.stmts.push(parse_quote! { fn build_bytes<#lifetime>( &self, @@ -559,7 +479,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { ::domain::new_base::build::TruncationError, > { #(bytes = <#tys as ::domain::new_base::build::BuildBytes> - ::build_bytes(&self.#names, bytes)?;)* + ::build_bytes(&self.#members, bytes)?;)* Ok(bytes) } }); @@ -619,3 +539,10 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { .unwrap_or_else(syn::Error::into_compile_error) .into() } + +//----------- Utility Functions ---------------------------------------------- + +/// Add a `field_` prefix to member names. +fn field_prefixed(member: Member) -> Ident { + format_ident!("field_{}", member) +} From 21dfd3daa6a612b7b662af54736ab7a44316342f Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 14:00:43 +0100 Subject: [PATCH 069/167] [new_edns] Impl RFC 8914 "Extended DNS errors" --- src/new_base/parse/mod.rs | 17 +++ src/new_edns/ext_err.rs | 210 ++++++++++++++++++++++++++++++++++++++ src/new_edns/mod.rs | 3 + 3 files changed, 230 insertions(+) create mode 100644 src/new_edns/ext_err.rs diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index c4bba79e4..32bc3627a 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -321,6 +321,23 @@ unsafe impl ParseBytesByRef for [u8] { } } +unsafe impl ParseBytesByRef for str { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + core::str::from_utf8(bytes).map_err(|_| ParseError) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + // NOTE: The Rust Reference indicates that 'str' has the same layout + // as '[u8]' [1]. This is also the most natural layout for it. Since + // there's no way to construct a '*const str' from raw parts, we will + // just construct a raw slice and transmute it. + // + // [1]: https://doc.rust-lang.org/reference/type-layout.html#str-layout + + self.as_bytes().ptr_with_address(addr) as *const Self + } +} + unsafe impl SplitBytesByRef for [T; N] { fn split_bytes_by_ref( mut bytes: &[u8], diff --git a/src/new_edns/ext_err.rs b/src/new_edns/ext_err.rs new file mode 100644 index 000000000..030df6814 --- /dev/null +++ b/src/new_edns/ext_err.rs @@ -0,0 +1,210 @@ +//! Extended DNS errors. +//! +//! See [RFC 8914](https://datatracker.ietf.org/doc/html/rfc8914). + +use core::fmt; + +use domain_macros::*; + +use zerocopy::network_endian::U16; + +//----------- ExtError ------------------------------------------------------- + +/// An extended DNS error. +#[derive(ParseBytesByRef)] +#[repr(C)] +pub struct ExtError { + /// The error code. + pub code: ExtErrorCode, + + /// A human-readable description of the error. + text: str, +} + +impl ExtError { + /// A human-readable description of the error. + pub fn text(&self) -> Option<&str> { + if !self.text.is_empty() { + Some(self.text.strip_suffix('\0').unwrap_or(&self.text)) + } else { + None + } + } +} + +//----------- ExtErrorCode --------------------------------------------------- + +/// The code for an extended DNS error. +#[derive( + Copy, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(transparent)] +pub struct ExtErrorCode { + inner: U16, +} + +//--- Associated Constants + +impl ExtErrorCode { + const fn new(inner: u16) -> Self { + Self { + inner: U16::new(inner), + } + } + + /// An unspecified extended error. + /// + /// This should be used when there is no other appropriate error code. + pub const OTHER: Self = Self::new(0); + + /// DNSSEC validation failed because a DNSKEY used an unknown algorithm. + pub const BAD_DNSKEY_ALG: Self = Self::new(1); + + /// DNSSEC validation failed because a DS set used an unknown algorithm. + pub const BAD_DS_ALG: Self = Self::new(2); + + /// An up-to-date answer could not be retrieved in time. + pub const STALE_ANSWER: Self = Self::new(3); + + /// Policy dictated that a forged answer be returned. + pub const FORGED_ANSWER: Self = Self::new(4); + + /// The DNSSEC validity of the answer could not be determined. + pub const DNSSEC_INDETERMINATE: Self = Self::new(5); + + /// The answer was invalid as per DNSSEC. + pub const DNSSEC_BOGUS: Self = Self::new(6); + + /// The DNSSEC signature of the answer expired. + pub const SIG_EXPIRED: Self = Self::new(7); + + /// The DNSSEC signature of the answer is valid in the future. + pub const SIG_FUTURE: Self = Self::new(8); + + /// DNSSEC validation failed because a DNSKEY record was missing. + pub const DNSKEY_MISSING: Self = Self::new(9); + + /// DNSSEC validation failed because RRSIGs were unexpectedly missing. + pub const RRSIGS_MISSING: Self = Self::new(10); + + /// DNSSEC validation failed because a DNSKEY wasn't a ZSK. + pub const NOT_ZSK: Self = Self::new(11); + + /// DNSSEC validation failed because an NSEC(3) record could not be found. + pub const NSEC_MISSING: Self = Self::new(12); + + /// The server failure error was cached from an upstream. + pub const CACHED_ERROR: Self = Self::new(13); + + /// The server is not ready to serve requests. + pub const NOT_READY: Self = Self::new(14); + + /// The request is blocked by internal policy. + pub const BLOCKED: Self = Self::new(15); + + /// The request is blocked by external policy. + pub const CENSORED: Self = Self::new(16); + + /// The request is blocked by the client's own filters. + pub const FILTERED: Self = Self::new(17); + + /// The client is prohibited from making requests. + pub const PROHIBITED: Self = Self::new(18); + + /// An up-to-date answer could not be retrieved in time. + pub const STALE_NXDOMAIN: Self = Self::new(19); + + /// The request cannot be answered authoritatively. + pub const NOT_AUTHORITATIVE: Self = Self::new(20); + + /// The request / operation is not supported. + pub const NOT_SUPPORTED: Self = Self::new(21); + + /// No upstream authorities answered the request (in time). + pub const NO_REACHABLE_AUTHORITY: Self = Self::new(22); + + /// An unrecoverable network error occurred. + pub const NETWORK_ERROR: Self = Self::new(23); + + /// The server's local zone data is invalid. + pub const INVALID_DATA: Self = Self::new(24); + + /// An impure operation was stated in a DNS-over-QUIC 0-RTT packet. + /// + /// See [RFC 9250](https://datatracker.ietf.org/doc/html/rfc9250). + pub const TOO_EARLY: Self = Self::new(26); + + /// DNSSEC validation failed because an NSEC3 parameter was unsupported. + pub const BAD_NSEC3_ITERS: Self = Self::new(27); +} + +//--- Inspection + +impl ExtErrorCode { + /// Whether this is a private-use code. + /// + /// Private-use codes occupy the range 49152 to 65535 (inclusive). + pub fn is_private(&self) -> bool { + self.inner >= 49152 + } +} + +//--- Formatting + +impl fmt::Debug for ExtErrorCode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let text = match *self { + Self::OTHER => "other", + Self::BAD_DNSKEY_ALG => "unsupported DNSKEY algorithm", + Self::BAD_DS_ALG => "unspported DS digest type", + Self::STALE_ANSWER => "stale answer", + Self::FORGED_ANSWER => "forged answer", + Self::DNSSEC_INDETERMINATE => "DNSSEC indeterminate", + Self::DNSSEC_BOGUS => "DNSSEC bogus", + Self::SIG_EXPIRED => "signature expired", + Self::SIG_FUTURE => "signature not yet valid", + Self::DNSKEY_MISSING => "DNSKEY missing", + Self::RRSIGS_MISSING => "RRSIGs missing", + Self::NOT_ZSK => "no zone key bit set", + Self::NSEC_MISSING => "nsec missing", + Self::CACHED_ERROR => "cached error", + Self::NOT_READY => "not ready", + Self::BLOCKED => "blocked", + Self::CENSORED => "censored", + Self::FILTERED => "filtered", + Self::PROHIBITED => "prohibited", + Self::STALE_NXDOMAIN => "stale NXDOMAIN answer", + Self::NOT_AUTHORITATIVE => "not authoritative", + Self::NOT_SUPPORTED => "not supported", + Self::NO_REACHABLE_AUTHORITY => "no reachable authority", + Self::NETWORK_ERROR => "network error", + Self::INVALID_DATA => "invalid data", + Self::TOO_EARLY => "too early", + Self::BAD_NSEC3_ITERS => "unsupported NSEC3 iterations value", + + _ => { + return f + .debug_tuple("ExtErrorCode") + .field(&self.inner.get()) + .finish(); + } + }; + + f.debug_tuple("ExtErrorCode") + .field(&self.inner.get()) + .field(&text) + .finish() + } +} diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index e72529fb3..b14170919 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -24,6 +24,9 @@ use crate::{ mod cookie; pub use cookie::{Cookie, CookieRequest}; +mod ext_err; +pub use ext_err::{ExtError, ExtErrorCode}; + //----------- EdnsRecord ----------------------------------------------------- /// An Extended DNS record. From 16596213071f965c09c17a5053ae5e44a5ebb043 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 14:42:33 +0100 Subject: [PATCH 070/167] [new_edns] Impl parsing/building for 'EdnsOption' --- src/new_base/build/mod.rs | 10 +++ src/new_edns/cookie.rs | 4 +- src/new_edns/ext_err.rs | 26 +++++-- src/new_edns/mod.rs | 144 +++++++++++++++++++++++++++++++++++++- src/new_rdata/mod.rs | 2 +- 5 files changed, 175 insertions(+), 11 deletions(-) diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 56670e922..548b2d8fd 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -83,6 +83,15 @@ impl BuildBytes for u8 { } } +impl BuildBytes for str { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_bytes(bytes) + } +} + impl BuildBytes for U16 { fn build_bytes<'b>( &self, @@ -158,6 +167,7 @@ pub unsafe trait AsBytes { } unsafe impl AsBytes for u8 {} +unsafe impl AsBytes for str {} unsafe impl AsBytes for [T] {} unsafe impl AsBytes for [T; N] {} diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs index 466e8c606..1a815e615 100644 --- a/src/new_edns/cookie.rs +++ b/src/new_edns/cookie.rs @@ -130,7 +130,9 @@ impl fmt::Display for CookieRequest { //----------- Cookie --------------------------------------------------------- /// A DNS cookie. -#[derive(PartialEq, Eq, Hash, AsBytes, BuildBytes, ParseBytesByRef)] +#[derive( + Debug, PartialEq, Eq, Hash, AsBytes, BuildBytes, ParseBytesByRef, +)] #[repr(C)] pub struct Cookie { /// The request for this cookie. diff --git a/src/new_edns/ext_err.rs b/src/new_edns/ext_err.rs index 030df6814..6858aa713 100644 --- a/src/new_edns/ext_err.rs +++ b/src/new_edns/ext_err.rs @@ -11,7 +11,7 @@ use zerocopy::network_endian::U16; //----------- ExtError ------------------------------------------------------- /// An extended DNS error. -#[derive(ParseBytesByRef)] +#[derive(AsBytes, ParseBytesByRef)] #[repr(C)] pub struct ExtError { /// The error code. @@ -32,6 +32,17 @@ impl ExtError { } } +//--- Formatting + +impl fmt::Debug for ExtError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ExtError") + .field("code", &self.code) + .field("text", &self.text()) + .finish() + } +} + //----------- ExtErrorCode --------------------------------------------------- /// The code for an extended DNS error. @@ -52,15 +63,16 @@ impl ExtError { )] #[repr(transparent)] pub struct ExtErrorCode { - inner: U16, + /// The error code. + pub code: U16, } //--- Associated Constants impl ExtErrorCode { - const fn new(inner: u16) -> Self { + const fn new(code: u16) -> Self { Self { - inner: U16::new(inner), + code: U16::new(code), } } @@ -157,7 +169,7 @@ impl ExtErrorCode { /// /// Private-use codes occupy the range 49152 to 65535 (inclusive). pub fn is_private(&self) -> bool { - self.inner >= 49152 + self.code >= 49152 } } @@ -197,13 +209,13 @@ impl fmt::Debug for ExtErrorCode { _ => { return f .debug_tuple("ExtErrorCode") - .field(&self.inner.get()) + .field(&self.code.get()) .finish(); } }; f.debug_tuple("ExtErrorCode") - .field(&self.inner.get()) + .field(&self.code.get()) .field(&text) .finish() } diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index b14170919..3b3ab2a80 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -10,6 +10,7 @@ use domain_macros::*; use crate::{ new_base::{ + build::{AsBytes, BuildBytes, TruncationError}, parse::{ ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, SplitBytes, SplitFromMessage, @@ -208,17 +209,123 @@ impl fmt::Debug for EdnsFlags { #[derive(Debug)] #[non_exhaustive] pub enum EdnsOption<'b> { + /// A request for a DNS cookie. + CookieRequest(&'b CookieRequest), + + /// A DNS cookie. + Cookie(&'b Cookie), + + /// An extended DNS error. + ExtError(&'b ExtError), + /// An unknown option. Unknown(OptionCode, &'b UnknownOption), } +//--- Inspection + +impl EdnsOption<'_> { + /// The code for this option. + pub fn code(&self) -> OptionCode { + match self { + Self::CookieRequest(_) => OptionCode::COOKIE, + Self::Cookie(_) => OptionCode::COOKIE, + Self::ExtError(_) => OptionCode::EXT_ERROR, + Self::Unknown(code, _) => *code, + } + } +} + +//--- Parsing from bytes + +impl<'b> ParseBytes<'b> for EdnsOption<'b> { + fn parse_bytes(bytes: &'b [u8]) -> Result { + let (code, rest) = OptionCode::split_bytes(bytes)?; + let (size, rest) = U16::split_bytes(rest)?; + if rest.len() != size.get() as usize { + return Err(ParseError); + } + + match code { + OptionCode::COOKIE => match size.get() { + 8 => CookieRequest::parse_bytes_by_ref(rest) + .map(Self::CookieRequest), + 16..=40 => Cookie::parse_bytes_by_ref(rest).map(Self::Cookie), + _ => Err(ParseError), + }, + + OptionCode::EXT_ERROR => { + ExtError::parse_bytes_by_ref(rest).map(Self::ExtError) + } + + _ => { + let data = UnknownOption::parse_bytes_by_ref(rest)?; + Ok(Self::Unknown(code, data)) + } + } + } +} + +impl<'b> SplitBytes<'b> for EdnsOption<'b> { + fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { + let (code, rest) = OptionCode::split_bytes(bytes)?; + let (size, rest) = U16::split_bytes(rest)?; + if rest.len() < size.get() as usize { + return Err(ParseError); + } + let (bytes, rest) = rest.split_at(size.get() as usize); + + match code { + OptionCode::COOKIE => match size.get() { + 8 => CookieRequest::parse_bytes_by_ref(bytes) + .map(Self::CookieRequest), + 16..=40 => { + Cookie::parse_bytes_by_ref(bytes).map(Self::Cookie) + } + _ => Err(ParseError), + }, + + OptionCode::EXT_ERROR => { + ExtError::parse_bytes_by_ref(bytes).map(Self::ExtError) + } + + _ => { + let data = UnknownOption::parse_bytes_by_ref(bytes)?; + Ok(Self::Unknown(code, data)) + } + } + .map(|this| (this, rest)) + } +} + +//--- Building byte strings + +impl BuildBytes for EdnsOption<'_> { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + bytes = self.code().build_bytes(bytes)?; + + let data = match self { + Self::CookieRequest(this) => this.as_bytes(), + Self::Cookie(this) => this.as_bytes(), + Self::ExtError(this) => this.as_bytes(), + Self::Unknown(_, this) => this.as_bytes(), + }; + + bytes = U16::new(data.len() as u16).build_bytes(bytes)?; + bytes = data.build_bytes(bytes)?; + Ok(bytes) + } +} + //----------- OptionCode ----------------------------------------------------- /// An Extended DNS option code. #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -237,11 +344,44 @@ pub struct OptionCode { pub code: U16, } +//--- Associated Constants + +impl OptionCode { + const fn new(code: u16) -> Self { + Self { + code: U16::new(code), + } + } + + /// A DNS cookie (request). + pub const COOKIE: Self = Self::new(10); + + /// An extended DNS error. + pub const EXT_ERROR: Self = Self::new(15); +} + +//--- Formatting + +impl fmt::Debug for OptionCode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::COOKIE => "OptionCode::COOKIE", + Self::EXT_ERROR => "OptionCode::EXT_ERROR", + _ => { + return f + .debug_tuple("OptionCode") + .field(&self.code.get()) + .finish(); + } + }) + } +} + //----------- UnknownOption -------------------------------------------------- /// Data for an unknown Extended DNS option. #[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] -#[repr(C)] +#[repr(transparent)] pub struct UnknownOption { /// The unparsed option data. pub octets: [u8], diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 1be038e45..ebdcc7743 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -174,7 +174,7 @@ impl BuildBytes for RecordData<'_, N> { /// Data for an unknown DNS record type. #[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] -#[repr(C)] +#[repr(transparent)] pub struct UnknownRecordData { /// The unparsed option data. pub octets: [u8], From 7bb6d2c3506748014816b76d81da01aa88876e4f Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 18:27:27 +0100 Subject: [PATCH 071/167] [new_base] Add module 'wire' to replace 'zerocopy' This ended up collecting a lot of small changes as I tried to get things to compile. - All the bytes parsing/building traits have been moved to 'wire'. - The 'wire::ints' module replaces 'U16' and 'U32' from 'zerocopy'. - '{Parse,Split}BytesByRef' now support parsing from '&mut'. - Every derive macro is documented under a re-export in 'wire'. The remaining contents of 'new_base::{build, parse}' might get moved into a shared 'message' module at some point. We'll see. --- Cargo.lock | 30 +- Cargo.toml | 6 - macros/Cargo.toml | 2 +- macros/src/lib.rs | 138 ++++++--- src/new_base/build/builder.rs | 23 +- src/new_base/build/mod.rs | 153 +--------- src/new_base/charstr.rs | 11 +- src/new_base/message.rs | 24 +- src/new_base/mod.rs | 1 + src/new_base/name/label.rs | 6 +- src/new_base/name/reversed.rs | 13 +- src/new_base/parse/mod.rs | 346 +---------------------- src/new_base/question.rs | 7 +- src/new_base/record.rs | 35 +-- src/new_base/serial.rs | 4 +- src/new_base/wire/build.rs | 192 +++++++++++++ src/new_base/wire/ints.rs | 282 +++++++++++++++++++ src/new_base/wire/mod.rs | 81 ++++++ src/new_base/wire/parse.rs | 510 ++++++++++++++++++++++++++++++++++ src/new_edns/cookie.rs | 4 +- src/new_edns/ext_err.rs | 4 +- src/new_edns/mod.rs | 10 +- src/new_rdata/basic.rs | 15 +- src/new_rdata/ipv6.rs | 5 +- src/new_rdata/mod.rs | 10 +- 25 files changed, 1253 insertions(+), 659 deletions(-) create mode 100644 src/new_base/wire/build.rs create mode 100644 src/new_base/wire/ints.rs create mode 100644 src/new_base/wire/mod.rs create mode 100644 src/new_base/wire/parse.rs diff --git a/Cargo.lock b/Cargo.lock index d9833efa5..953edfb0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -279,8 +279,6 @@ dependencies = [ "tracing", "tracing-subscriber", "webpki-roots", - "zerocopy 0.8.13", - "zerocopy-derive 0.8.13", ] [[package]] @@ -809,7 +807,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -1196,9 +1194,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.79" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -1702,16 +1700,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive 0.7.35", -] - -[[package]] -name = "zerocopy" -version = "0.8.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67914ab451f3bfd2e69e5e9d2ef3858484e7074d63f204fd166ec391b54de21d" -dependencies = [ - "zerocopy-derive 0.8.13", + "zerocopy-derive", ] [[package]] @@ -1725,17 +1714,6 @@ dependencies = [ "syn", ] -[[package]] -name = "zerocopy-derive" -version = "0.8.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7988d73a4303ca289df03316bc490e934accf371af6bc745393cf3c2c5c4f25d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 041d83731..ebe95a4cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,12 +50,6 @@ tokio-stream = { version = "0.1.1", optional = true } tracing = { version = "0.1.40", optional = true } tracing-subscriber = { version = "0.3.18", optional = true, features = ["env-filter"] } -# 'zerocopy' provides simple derives for converting types to and from byte -# representations, along with network-endian integer primitives. These are -# used to define simple elements of DNS messages and their serialization. -zerocopy = "0.8.5" -zerocopy-derive = "0.8.5" - [features] default = ["std", "rand"] diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 263db27af..7060a61eb 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -22,7 +22,7 @@ version = "1.0" [dependencies.syn] version = "2.0" -features = ["visit"] +features = ["full", "visit"] [dependencies.quote] version = "1.0" diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 99d209fff..3fa1bc18b 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -47,7 +47,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { ); skeleton.lifetimes.push(param); skeleton.bound = Some( - parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); // Inspect the 'struct' fields. @@ -58,7 +58,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); } @@ -70,10 +70,10 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< (Self, & #lifetime [::domain::__core::primitive::u8]), - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { #(let (#init_vars, bytes) = - <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> + <#tys as ::domain::new_base::wire::SplitBytes<#lifetime>> ::split_bytes(bytes)?;)* Ok((#builder, bytes)) } @@ -119,7 +119,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { ); skeleton.lifetimes.push(param); skeleton.bound = Some( - parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), + parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), ); // Inspect the 'struct' fields. @@ -130,13 +130,13 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::SplitBytes<#lifetime>), + parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); } if let Some(field) = data.unsized_field() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::ParseBytes<#lifetime>), + parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), ); } @@ -147,12 +147,12 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< Self, - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { if bytes.is_empty() { Ok(#builder) } else { - Err(::domain::new_base::parse::ParseError) + Err(::domain::new_base::wire::ParseError) } } }); @@ -170,13 +170,13 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< Self, - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { #(let (#init_vars, bytes) = - <#tys as ::domain::new_base::parse::SplitBytes<#lifetime>> + <#tys as ::domain::new_base::wire::SplitBytes<#lifetime>> ::split_bytes(bytes)?;)* let #unsized_init_var = - <#unsized_ty as ::domain::new_base::parse::ParseBytes<#lifetime>> + <#unsized_ty as ::domain::new_base::wire::ParseBytes<#lifetime>> ::parse_bytes(bytes)?; Ok(#builder) } @@ -217,7 +217,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = - Some(parse_quote!(::domain::new_base::parse::SplitBytesByRef)); + Some(parse_quote!(::domain::new_base::wire::SplitBytesByRef)); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -226,7 +226,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::SplitBytesByRef), + parse_quote!(::domain::new_base::wire::SplitBytesByRef), ); } @@ -237,7 +237,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< (&Self, &[::domain::__core::primitive::u8]), - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { Ok(( // SAFETY: 'Self' is a 'struct' with no fields, @@ -260,17 +260,17 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< (&Self, &[::domain::__core::primitive::u8]), - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { let start = bytes.as_ptr(); #(let (_, bytes) = - <#tys as ::domain::new_base::parse::SplitBytesByRef> + <#tys as ::domain::new_base::wire::SplitBytesByRef> ::split_bytes_by_ref(bytes)?;)* let (last, rest) = - <#unsized_ty as ::domain::new_base::parse::SplitBytesByRef> + <#unsized_ty as ::domain::new_base::wire::SplitBytesByRef> ::split_bytes_by_ref(bytes)?; let ptr = - <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> ::ptr_with_address(last, start as *const ()); // SAFETY: @@ -287,6 +287,40 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }); + // Define 'split_bytes_by_mut()'. + let tys = data.sized_fields().map(|f| &f.ty); + skeleton.contents.stmts.push(parse_quote! { + fn split_bytes_by_mut( + bytes: &mut [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + (&mut Self, &mut [::domain::__core::primitive::u8]), + ::domain::new_base::wire::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_mut(bytes)?;)* + let (last, rest) = + <#unsized_ty as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_mut(bytes)?; + let ptr = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. + Ok((unsafe { &mut *(ptr as *const Self as *mut Self) }, rest)) + } + }); + Ok(skeleton.into_token_stream()) } @@ -322,7 +356,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = - Some(parse_quote!(::domain::new_base::parse::ParseBytesByRef)); + Some(parse_quote!(::domain::new_base::wire::ParseBytesByRef)); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -331,13 +365,13 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::SplitBytesByRef), + parse_quote!(::domain::new_base::wire::SplitBytesByRef), ); } if let Some(field) = data.unsized_field() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::parse::ParseBytesByRef), + parse_quote!(::domain::new_base::wire::ParseBytesByRef), ); } @@ -348,7 +382,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< &Self, - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { if bytes.is_empty() { // SAFETY: 'Self' is a 'struct' with no fields, @@ -356,7 +390,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // constructed at any address. Ok(unsafe { &*bytes.as_ptr().cast::() }) } else { - Err(::domain::new_base::parse::ParseError) + Err(::domain::new_base::wire::ParseError) } } }); @@ -381,17 +415,17 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< &Self, - ::domain::new_base::parse::ParseError, + ::domain::new_base::wire::ParseError, > { let start = bytes.as_ptr(); #(let (_, bytes) = - <#tys as ::domain::new_base::parse::SplitBytesByRef> + <#tys as ::domain::new_base::wire::SplitBytesByRef> ::split_bytes_by_ref(bytes)?;)* let last = - <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> ::parse_bytes_by_ref(bytes)?; let ptr = - <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> ::ptr_with_address(last, start as *const ()); // SAFETY: @@ -408,11 +442,45 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }); + // Define 'parse_bytes_by_mut()'. + let tys = data.sized_fields().map(|f| &f.ty); + skeleton.contents.stmts.push(parse_quote! { + fn parse_bytes_by_mut( + bytes: &mut [::domain::__core::primitive::u8], + ) -> ::domain::__core::result::Result< + &mut Self, + ::domain::new_base::wire::ParseError, + > { + let start = bytes.as_ptr(); + #(let (_, bytes) = + <#tys as ::domain::new_base::wire::SplitBytesByRef> + ::split_bytes_by_mut(bytes)?;)* + let last = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::parse_bytes_by_mut(bytes)?; + let ptr = + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> + ::ptr_with_address(last, start as *const ()); + + // SAFETY: + // - The original 'bytes' contained a valid instance of every + // field in 'Self', in succession. + // - Every field implements 'ParseBytesByRef' and so has no + // alignment restriction. + // - 'Self' is unaligned, since every field is unaligned, and + // any explicit alignment modifiers only make it unaligned. + // - 'start' is thus the start of a valid instance of 'Self'. + // - 'ptr' has the same address as 'start' but can be cast to + // 'Self', since it has the right pointer metadata. + Ok(unsafe { &mut *(ptr as *const Self as *mut Self) }) + } + }); + // Define 'ptr_with_address()'. let unsized_member = data.unsized_member(); skeleton.contents.stmts.push(parse_quote! { fn ptr_with_address(&self, addr: *const ()) -> *const Self { - <#unsized_ty as ::domain::new_base::parse::ParseBytesByRef> + <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> ::ptr_with_address(&self.#unsized_member, addr) as *const Self } @@ -451,7 +519,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, false); skeleton.bound = - Some(parse_quote!(::domain::new_base::build::BuildBytes)); + Some(parse_quote!(::domain::new_base::wire::BuildBytes)); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -463,7 +531,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::build::BuildBytes), + parse_quote!(::domain::new_base::wire::BuildBytes), ); } @@ -476,9 +544,9 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { mut bytes: & #lifetime mut [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< & #lifetime mut [::domain::__core::primitive::u8], - ::domain::new_base::build::TruncationError, + ::domain::new_base::wire::TruncationError, > { - #(bytes = <#tys as ::domain::new_base::build::BuildBytes> + #(bytes = <#tys as ::domain::new_base::wire::BuildBytes> ::build_bytes(&self.#members, bytes)?;)* Ok(bytes) } @@ -519,13 +587,13 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = - Some(parse_quote!(::domain::new_base::build::AsBytes)); + Some(parse_quote!(::domain::new_base::wire::AsBytes)); // Establish bounds on the fields. for field in data.fields.iter() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::build::AsBytes), + parse_quote!(::domain::new_base::wire::AsBytes), ); } diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 9245b9011..e02da91db 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -6,11 +6,11 @@ use core::{ ptr::{self, NonNull}, }; -use zerocopy::{FromBytes, IntoBytes, SizeError}; - -use crate::new_base::{name::RevName, Header, Message}; - -use super::{BuildBytes, TruncationError}; +use crate::new_base::{ + name::RevName, + wire::{AsBytes, BuildBytes, ParseBytesByRef, TruncationError}, + Header, Message, +}; //----------- Builder -------------------------------------------------------- @@ -82,8 +82,7 @@ impl<'b> Builder<'b> { context: &'b mut BuilderContext, ) -> Self { assert!(buffer.len() >= 12); - let message = Message::mut_from_bytes(buffer) - .map_err(SizeError::from) + let message = Message::parse_bytes_by_mut(buffer) .expect("A 'Message' can fit in 12 bytes"); context.size = 0; context.max_size = message.contents.len(); @@ -156,9 +155,8 @@ impl<'b> Builder<'b> { pub fn message(&self) -> &Message { // SAFETY: All of 'message' can be immutably borrowed by 'self'. let message = unsafe { &*self.message.as_ptr() }; - let message = message.as_bytes(); - Message::ref_from_bytes_with_elems(message, self.commit) - .map_err(SizeError::from) + let message = &message.as_bytes()[..12 + self.commit]; + Message::parse_bytes_by_ref(message) .expect("'message' represents a valid 'Message'") } @@ -170,9 +168,8 @@ impl<'b> Builder<'b> { pub fn cur_message(&self) -> &Message { // SAFETY: All of 'message' can be immutably borrowed by 'self'. let message = unsafe { &*self.message.as_ptr() }; - let message = message.as_bytes(); - Message::ref_from_bytes_with_elems(message, self.context.size) - .map_err(SizeError::from) + let message = &message.as_bytes()[..12 + self.context.size]; + Message::parse_bytes_by_ref(message) .expect("'message' represents a valid 'Message'") } diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 548b2d8fd..2faca3c16 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -1,12 +1,10 @@ //! Building DNS messages in the wire format. -use core::fmt; - -use zerocopy::network_endian::{U16, U32}; - mod builder; pub use builder::{Builder, BuilderContext}; +pub use super::wire::TruncationError; + //----------- Message-aware building traits ---------------------------------- /// Building into a DNS message. @@ -41,150 +39,3 @@ impl BuildIntoMessage for [u8] { Ok(()) } } - -//----------- Low-level building traits -------------------------------------- - -/// Serializing into a byte string. -pub trait BuildBytes { - /// Serialize into a byte string. - /// - /// `self` is serialized into a byte string and written to the given - /// buffer. If the buffer is large enough, the whole object is written - /// and the remaining (unmodified) part of the buffer is returned. - /// - /// if the buffer is too small, a [`TruncationError`] is returned (and - /// parts of the buffer may be modified). - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError>; -} - -impl BuildBytes for &T { - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - T::build_bytes(*self, bytes) - } -} - -impl BuildBytes for u8 { - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - if let Some((elem, rest)) = bytes.split_first_mut() { - *elem = *self; - Ok(rest) - } else { - Err(TruncationError) - } - } -} - -impl BuildBytes for str { - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_bytes(bytes) - } -} - -impl BuildBytes for U16 { - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_bytes(bytes) - } -} - -impl BuildBytes for U32 { - fn build_bytes<'b>( - &self, - bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - self.as_bytes().build_bytes(bytes) - } -} - -impl BuildBytes for [T] { - fn build_bytes<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - for elem in self { - bytes = elem.build_bytes(bytes)?; - } - Ok(bytes) - } -} - -impl BuildBytes for [T; N] { - fn build_bytes<'b>( - &self, - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - for elem in self { - bytes = elem.build_bytes(bytes)?; - } - Ok(bytes) - } -} - -/// Interpreting a value as a byte string. -/// -/// # Safety -/// -/// A type `T` can soundly implement [`AsBytes`] if and only if: -/// -/// - It has no padding bytes. -/// - It has no interior mutability. -pub unsafe trait AsBytes { - /// Interpret this value as a sequence of bytes. - /// - /// ## Invariants - /// - /// For the statement `let bytes = this.as_bytes();`, - /// - /// - `bytes.as_ptr() as usize == this as *const _ as usize`. - /// - `bytes.len() == core::mem::size_of_val(this)`. - /// - /// The default implementation automatically satisfies these invariants. - fn as_bytes(&self) -> &[u8] { - // SAFETY: - // - 'Self' has no padding bytes and no interior mutability. - // - Its size in memory is exactly 'size_of_val(self)'. - unsafe { - core::slice::from_raw_parts( - self as *const Self as *const u8, - core::mem::size_of_val(self), - ) - } - } -} - -unsafe impl AsBytes for u8 {} -unsafe impl AsBytes for str {} - -unsafe impl AsBytes for [T] {} -unsafe impl AsBytes for [T; N] {} - -unsafe impl AsBytes for U16 {} -unsafe impl AsBytes for U32 {} - -//----------- TruncationError ------------------------------------------------ - -/// A DNS message did not fit in a buffer. -#[derive(Clone, Debug, PartialEq, Hash)] -pub struct TruncationError; - -//--- Formatting - -impl fmt::Display for TruncationError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("A buffer was too small to fit a DNS message") - } -} diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 2a82e95fa..979e8be20 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -2,13 +2,12 @@ use core::{fmt, ops::Range}; -use zerocopy::IntoBytes; - use super::{ - build::{self, BuildBytes, BuildIntoMessage, TruncationError}, - parse::{ - ParseBytes, ParseError, ParseFromMessage, SplitBytes, - SplitFromMessage, + build::{self, BuildIntoMessage}, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{ + AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes, + TruncationError, }, Message, }; diff --git a/src/new_base/message.rs b/src/new_base/message.rs index 3307609bb..9c27d384f 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -2,17 +2,14 @@ use core::fmt; -use zerocopy::network_endian::U16; -use zerocopy_derive::*; - use domain_macros::{AsBytes, *}; +use super::wire::U16; + //----------- Message -------------------------------------------------------- /// A DNS message. -#[derive( - FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned, ParseBytesByRef, -)] +#[derive(AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C, packed)] pub struct Message { /// The message header. @@ -30,11 +27,6 @@ pub struct Message { Clone, Debug, Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, AsBytes, BuildBytes, ParseBytes, @@ -76,11 +68,6 @@ impl fmt::Display for Header { Clone, Default, Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, AsBytes, BuildBytes, ParseBytes, @@ -243,11 +230,6 @@ impl fmt::Display for HeaderFlags { PartialEq, Eq, Hash, - FromBytes, - IntoBytes, - KnownLayout, - Immutable, - Unaligned, AsBytes, BuildBytes, ParseBytes, diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 3c2e34068..df8632884 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -32,3 +32,4 @@ pub use serial::Serial; pub mod build; pub mod parse; +pub mod wire; diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 78ef94008..9cb4d1d85 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -9,7 +9,7 @@ use core::{ use domain_macros::AsBytes; -use crate::new_base::parse::{ParseBytes, ParseError, SplitBytes}; +use crate::new_base::wire::{ParseBytes, ParseError, SplitBytes}; //----------- Label ---------------------------------------------------------- @@ -56,8 +56,8 @@ impl<'a> SplitBytes<'a> for &'a Label { fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { let (&size, rest) = bytes.split_first().ok_or(ParseError)?; if size < 64 && rest.len() >= size as usize { - let (label, rest) = bytes.split_at(1 + size as usize); - // SAFETY: 'label' begins with a valid length octet. + let (label, rest) = rest.split_at(size as usize); + // SAFETY: 'label' is 'size < 64' bytes in size. Ok((unsafe { Label::from_bytes_unchecked(label) }, rest)) } else { Err(ParseError) diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 6fae3c0f2..ba7cdb8c6 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -8,13 +8,12 @@ use core::{ ops::{Deref, Range}, }; -use zerocopy::IntoBytes; - use crate::new_base::{ - build::{self, BuildBytes, BuildIntoMessage, TruncationError}, - parse::{ - ParseBytes, ParseError, ParseFromMessage, SplitBytes, - SplitFromMessage, + build::{self, BuildIntoMessage}, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{ + AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes, + TruncationError, }, Message, }; @@ -228,7 +227,7 @@ impl RevNameBuf { /// Construct an empty, invalid buffer. fn empty() -> Self { Self { - offset: 0, + offset: 255, buffer: [0; 255], } } diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 32bc3627a..d36dd9543 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,11 +1,6 @@ //! Parsing DNS messages from the wire format. -use core::{fmt, ops::Range}; - -use zerocopy::{ - network_endian::{U16, U32}, - FromBytes, IntoBytes, -}; +use core::ops::Range; mod message; pub use message::{MessagePart, ParseMessage, VisitMessagePart}; @@ -16,7 +11,12 @@ pub use question::{ParseQuestion, ParseQuestions, VisitQuestion}; mod record; pub use record::{ParseRecord, ParseRecords, VisitRecord}; -use super::Message; +pub use super::wire::ParseError; + +use super::{ + wire::{AsBytes, ParseBytesByRef, SplitBytesByRef}, + Message, +}; //----------- Message-aware parsing traits ----------------------------------- @@ -68,335 +68,3 @@ impl<'a, T: ?Sized + ParseBytesByRef> ParseFromMessage<'a> for &'a T { T::parse_bytes_by_ref(bytes) } } - -//----------- Low-level parsing traits --------------------------------------- - -/// Parsing from the start of a byte string. -pub trait SplitBytes<'a>: Sized + ParseBytes<'a> { - /// Parse a value of [`Self`] from the start of the byte string. - /// - /// If parsing is successful, the parsed value and the rest of the string - /// are returned. Otherwise, a [`ParseError`] is returned. - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError>; -} - -/// Parsing from a byte string. -pub trait ParseBytes<'a>: Sized { - /// Parse a value of [`Self`] from the given byte string. - /// - /// If parsing is successful, the parsed value is returned. Otherwise, a - /// [`ParseError`] is returned. - fn parse_bytes(bytes: &'a [u8]) -> Result; -} - -impl<'a, T: ?Sized + SplitBytesByRef> SplitBytes<'a> for &'a T { - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - T::split_bytes_by_ref(bytes).map_err(|_| ParseError) - } -} - -impl<'a, T: ?Sized + ParseBytesByRef> ParseBytes<'a> for &'a T { - fn parse_bytes(bytes: &'a [u8]) -> Result { - T::parse_bytes_by_ref(bytes).map_err(|_| ParseError) - } -} - -impl<'a> SplitBytes<'a> for u8 { - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - bytes.split_first().map(|(&f, r)| (f, r)).ok_or(ParseError) - } -} - -impl<'a> ParseBytes<'a> for u8 { - fn parse_bytes(bytes: &'a [u8]) -> Result { - let [result] = bytes else { - return Err(ParseError); - }; - - Ok(*result) - } -} - -impl<'a> SplitBytes<'a> for U16 { - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - Self::read_from_prefix(bytes).map_err(Into::into) - } -} - -impl<'a> ParseBytes<'a> for U16 { - fn parse_bytes(bytes: &'a [u8]) -> Result { - Self::read_from_bytes(bytes).map_err(Into::into) - } -} - -impl<'a> SplitBytes<'a> for U32 { - fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - Self::read_from_prefix(bytes).map_err(Into::into) - } -} - -impl<'a> ParseBytes<'a> for U32 { - fn parse_bytes(bytes: &'a [u8]) -> Result { - Self::read_from_bytes(bytes).map_err(Into::into) - } -} - -/// Zero-copy parsing from the start of a byte string. -/// -/// This is an extension of [`ParseBytesByRef`] for types which can determine -/// their own length when parsing. It is usually implemented by [`Sized`] -/// types (where the length is just the size of the type), although it can be -/// sometimes implemented by unsized types. -/// -/// # Safety -/// -/// Every implementation of [`SplitBytesByRef`] must satisfy the invariants -/// documented on [`split_bytes_by_ref()`]. An incorrect implementation is -/// considered to cause undefined behaviour. -/// -/// [`split_bytes_by_ref()`]: Self::split_bytes_by_ref() -/// -/// Note that [`ParseBytesByRef`], required by this trait, also has several -/// invariants that need to be considered with care. -pub unsafe trait SplitBytesByRef: ParseBytesByRef { - /// Interpret a byte string as an instance of [`Self`]. - /// - /// The byte string will be validated and re-interpreted as a reference to - /// [`Self`]. The length of [`Self`] will be determined, possibly based - /// on the contents (but not the length!) of the input, and the remaining - /// bytes will be returned. If the input does not begin with a valid - /// instance of [`Self`], a [`ParseError`] is returned. - /// - /// ## Invariants - /// - /// For the statement `let (this, rest) = T::split_bytes_by_ref(bytes)?;`, - /// - /// - `bytes.as_ptr() == this as *const T as *const u8`. - /// - `bytes.len() == core::mem::size_of_val(this) + rest.len()`. - /// - `bytes.as_ptr().offset(size_of_val(this)) == rest.as_ptr()`. - fn split_bytes_by_ref(bytes: &[u8]) - -> Result<(&Self, &[u8]), ParseError>; -} - -/// Zero-copy parsing from a byte string. -/// -/// # Safety -/// -/// Every implementation of [`ParseBytesByRef`] must satisfy the invariants -/// documented on [`parse_bytes_by_ref()`] and [`ptr_with_address()`]. An -/// incorrect implementation is considered to cause undefined behaviour. -/// -/// [`parse_bytes_by_ref()`]: Self::parse_bytes_by_ref() -/// [`ptr_with_address()`]: Self::ptr_with_address() -/// -/// Implementing types must also have no alignment (i.e. a valid instance of -/// [`Self`] can occur at any address). This eliminates the possibility of -/// padding bytes, even when [`Self`] is part of a larger aggregate type. -pub unsafe trait ParseBytesByRef { - /// Interpret a byte string as an instance of [`Self`]. - /// - /// The byte string will be validated and re-interpreted as a reference to - /// [`Self`]. The whole byte string will be used. If the input is not a - /// valid instance of [`Self`], a [`ParseError`] is returned. - /// - /// ## Invariants - /// - /// For the statement `let this: &T = T::parse_bytes_by_ref(bytes)?;`, - /// - /// - `bytes.as_ptr() == this as *const T as *const u8`. - /// - `bytes.len() == core::mem::size_of_val(this)`. - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError>; - - /// Change the address of a pointer to [`Self`]. - /// - /// When [`Self`] is used as the last field in a type that also implements - /// [`ParseBytesByRef`], it may be dynamically sized, and so a pointer (or - /// reference) to it may include additional metadata. This metadata is - /// included verbatim in any reference/pointer to the containing type. - /// - /// When the containing type implements [`ParseBytesByRef`], it needs to - /// construct a reference/pointer to itself, which includes this metadata. - /// Rust does not currently offer a general way to extract this metadata - /// or pair it with another address, so this function is necessary. The - /// caller can construct a reference to [`Self`], then change its address - /// to point to the containing type, then cast that pointer to the right - /// type. - /// - /// # Implementing - /// - /// Most users will derive [`ParseBytesByRef`] and so don't need to worry - /// about this. For manual implementations: - /// - /// In the future, an adequate default implementation for this function - /// may be provided. Until then, it should be implemented using one of - /// the following expressions: - /// - /// ```ignore - /// fn ptr_with_address( - /// &self, - /// addr: *const (), - /// ) -> *const Self { - /// // If 'Self' is Sized: - /// addr.cast::() - /// - /// // If 'Self' is an aggregate whose last field is 'last': - /// self.last.ptr_with_address(addr) as *const Self - /// } - /// ``` - /// - /// # Invariants - /// - /// For the statement `let result = Self::ptr_with_address(ptr, addr);`: - /// - /// - `result as usize == addr as usize`. - /// - `core::ptr::metadata(result) == core::ptr::metadata(ptr)`. - fn ptr_with_address(&self, addr: *const ()) -> *const Self; -} - -unsafe impl SplitBytesByRef for u8 { - fn split_bytes_by_ref( - bytes: &[u8], - ) -> Result<(&Self, &[u8]), ParseError> { - bytes.split_first().ok_or(ParseError) - } -} - -unsafe impl ParseBytesByRef for u8 { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - let [result] = bytes else { - return Err(ParseError); - }; - - Ok(result) - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - addr.cast() - } -} - -unsafe impl SplitBytesByRef for U16 { - fn split_bytes_by_ref( - bytes: &[u8], - ) -> Result<(&Self, &[u8]), ParseError> { - Self::ref_from_prefix(bytes).map_err(Into::into) - } -} - -unsafe impl ParseBytesByRef for U16 { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - Self::ref_from_bytes(bytes).map_err(Into::into) - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - addr.cast() - } -} - -unsafe impl SplitBytesByRef for U32 { - fn split_bytes_by_ref( - bytes: &[u8], - ) -> Result<(&Self, &[u8]), ParseError> { - Self::ref_from_prefix(bytes).map_err(Into::into) - } -} - -unsafe impl ParseBytesByRef for U32 { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - Self::ref_from_bytes(bytes).map_err(Into::into) - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - addr.cast() - } -} - -unsafe impl ParseBytesByRef for [u8] { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - Ok(bytes) - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - core::ptr::slice_from_raw_parts(addr.cast(), self.len()) - } -} - -unsafe impl ParseBytesByRef for str { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - core::str::from_utf8(bytes).map_err(|_| ParseError) - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - // NOTE: The Rust Reference indicates that 'str' has the same layout - // as '[u8]' [1]. This is also the most natural layout for it. Since - // there's no way to construct a '*const str' from raw parts, we will - // just construct a raw slice and transmute it. - // - // [1]: https://doc.rust-lang.org/reference/type-layout.html#str-layout - - self.as_bytes().ptr_with_address(addr) as *const Self - } -} - -unsafe impl SplitBytesByRef for [T; N] { - fn split_bytes_by_ref( - mut bytes: &[u8], - ) -> Result<(&Self, &[u8]), ParseError> { - let start = bytes.as_ptr(); - for _ in 0..N { - (_, bytes) = T::split_bytes_by_ref(bytes)?; - } - - // SAFETY: - // - 'T::split_bytes_by_ref()' was called 'N' times on successive - // positions, thus the original 'bytes' starts with 'N' instances - // of 'T' (even if 'T' is a ZST and so all instances overlap). - // - 'N' consecutive 'T's have the same layout as '[T; N]'. - // - Thus it is safe to cast 'start' to '[T; N]'. - // - The referenced data has the same lifetime as the output. - Ok((unsafe { &*start.cast::<[T; N]>() }, bytes)) - } -} - -unsafe impl ParseBytesByRef for [T; N] { - fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { - let (this, rest) = Self::split_bytes_by_ref(bytes)?; - if rest.is_empty() { - Ok(this) - } else { - Err(ParseError) - } - } - - fn ptr_with_address(&self, addr: *const ()) -> *const Self { - addr.cast() - } -} - -//----------- ParseError ----------------------------------------------------- - -/// A DNS message parsing error. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct ParseError; - -//--- Formatting - -impl fmt::Display for ParseError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("DNS data could not be parsed from the wire format") - } -} - -//--- Conversion from 'zerocopy' errors - -impl From> for ParseError { - fn from(_: zerocopy::ConvertError) -> Self { - Self - } -} - -impl From> for ParseError { - fn from(_: zerocopy::SizeError) -> Self { - Self - } -} diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 4e93951aa..0dad0910a 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -2,14 +2,13 @@ use core::ops::Range; -use zerocopy::network_endian::U16; - use domain_macros::*; use super::{ - build::{self, AsBytes, BuildIntoMessage, TruncationError}, + build::{self, BuildIntoMessage}, name::RevNameBuf, - parse::{ParseError, ParseFromMessage, SplitFromMessage}, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{AsBytes, ParseError, TruncationError, U16}, Message, }; diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 391b95dee..2d84e0934 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -5,19 +5,13 @@ use core::{ ops::{Deref, Range}, }; -use zerocopy::{ - network_endian::{U16, U32}, - FromBytes, IntoBytes, -}; - -use domain_macros::*; - use super::{ - build::{self, AsBytes, BuildBytes, BuildIntoMessage, TruncationError}, + build::{self, BuildIntoMessage}, name::RevNameBuf, - parse::{ - ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, - SplitBytes, SplitFromMessage, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{ + AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, + SplitBytes, SplitBytesByRef, TruncationError, U16, U32, }, Message, }; @@ -160,9 +154,13 @@ where let (rtype, rest) = RType::split_bytes(rest)?; let (rclass, rest) = RClass::split_bytes(rest)?; let (ttl, rest) = TTL::split_bytes(rest)?; - let (size, rest) = U16::read_from_prefix(rest)?; + let (size, rest) = U16::split_bytes(rest)?; let size: usize = size.get().into(); - let (rdata, rest) = <[u8]>::ref_from_prefix_with_elems(rest, size)?; + if rest.len() < size { + return Err(ParseError); + } + + let (rdata, rest) = rest.split_at(size); let rdata = D::parse_record_data_bytes(rdata, rtype)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) @@ -179,10 +177,13 @@ where let (rtype, rest) = RType::split_bytes(rest)?; let (rclass, rest) = RClass::split_bytes(rest)?; let (ttl, rest) = TTL::split_bytes(rest)?; - let (size, rest) = U16::read_from_prefix(rest)?; + let (size, rest) = U16::split_bytes(rest)?; let size: usize = size.get().into(); - let rdata = <[u8]>::ref_from_bytes_with_elems(rest, size)?; - let rdata = D::parse_record_data_bytes(rdata, rtype)?; + if rest.len() != size { + return Err(ParseError); + } + + let rdata = D::parse_record_data_bytes(rest, rtype)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } @@ -205,7 +206,7 @@ where bytes = self.ttl.as_bytes().build_bytes(bytes)?; let (size, bytes) = - ::mut_from_prefix(bytes).map_err(|_| TruncationError)?; + U16::split_bytes_by_mut(bytes).map_err(|_| TruncationError)?; let bytes_len = bytes.len(); let rest = self.rdata.build_bytes(bytes)?; diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs index 4258c4b22..af0e4a1a1 100644 --- a/src/new_base/serial.rs +++ b/src/new_base/serial.rs @@ -8,10 +8,10 @@ use core::{ ops::{Add, AddAssign}, }; -use zerocopy::network_endian::U32; - use domain_macros::*; +use super::wire::U32; + //----------- Serial --------------------------------------------------------- /// A serial number. diff --git a/src/new_base/wire/build.rs b/src/new_base/wire/build.rs new file mode 100644 index 000000000..1b67a4d40 --- /dev/null +++ b/src/new_base/wire/build.rs @@ -0,0 +1,192 @@ +//! Building data in the basic network format. + +use core::fmt; + +//----------- BuildBytes ----------------------------------------------------- + +/// Serializing into a byte string. +pub trait BuildBytes { + /// Serialize into a byte string. + /// + /// `self` is serialized into a byte string and written to the given + /// buffer. If the buffer is large enough, the whole object is written + /// and the remaining (unmodified) part of the buffer is returned. + /// + /// if the buffer is too small, a [`TruncationError`] is returned (and + /// parts of the buffer may be modified). + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError>; +} + +impl BuildBytes for &T { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + T::build_bytes(*self, bytes) + } +} + +impl BuildBytes for u8 { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + if let Some((elem, rest)) = bytes.split_first_mut() { + *elem = *self; + Ok(rest) + } else { + Err(TruncationError) + } + } +} + +impl BuildBytes for str { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.as_bytes().build_bytes(bytes) + } +} + +impl BuildBytes for [T] { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + for elem in self { + bytes = elem.build_bytes(bytes)?; + } + Ok(bytes) + } +} + +impl BuildBytes for [T; N] { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + for elem in self { + bytes = elem.build_bytes(bytes)?; + } + Ok(bytes) + } +} + +/// Deriving [`BuildBytes`] automatically. +/// +/// [`BuildBytes`] can be derived on `struct`s (not `enum`s or `union`s). The +/// generated implementation will call [`build_bytes()`] with each field, in +/// the order they are declared. The trait implementation will be bounded by +/// the type of every field implementing [`BuildBytes`]. +/// +/// [`build_bytes()`]: BuildBytes::build_bytes() +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{BuildBytes, U32, TruncationError}; +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(BuildBytes)': +/// impl BuildBytes for Foo +/// where Bar: BuildBytes { +/// fn build_bytes<'bytes>( +/// &self, +/// mut bytes: &'bytes mut [u8], +/// ) -> Result<&'bytes mut [u8], TruncationError> { +/// bytes = self.a.build_bytes(bytes)?; +/// bytes = self.b.build_bytes(bytes)?; +/// Ok(bytes) +/// } +/// } +/// ``` +pub use domain_macros::BuildBytes; + +//----------- AsBytes -------------------------------------------------------- + +/// Interpreting a value as a byte string. +/// +/// # Safety +/// +/// A type `T` can soundly implement [`AsBytes`] if and only if: +/// +/// - It has no padding bytes. +/// - It has no interior mutability. +pub unsafe trait AsBytes { + /// Interpret this value as a sequence of bytes. + /// + /// ## Invariants + /// + /// For the statement `let bytes = this.as_bytes();`, + /// + /// - `bytes.as_ptr() as usize == this as *const _ as usize`. + /// - `bytes.len() == core::mem::size_of_val(this)`. + /// + /// The default implementation automatically satisfies these invariants. + fn as_bytes(&self) -> &[u8] { + // SAFETY: + // - 'Self' has no padding bytes and no interior mutability. + // - Its size in memory is exactly 'size_of_val(self)'. + unsafe { + core::slice::from_raw_parts( + self as *const Self as *const u8, + core::mem::size_of_val(self), + ) + } + } +} + +unsafe impl AsBytes for u8 {} +unsafe impl AsBytes for str {} + +unsafe impl AsBytes for [T] {} +unsafe impl AsBytes for [T; N] {} + +/// Deriving [`AsBytes`] automatically. +/// +/// [`AsBytes`] can be derived on `struct`s (not `enum`s or `union`s), where a +/// fixed memory layout (`repr(C)` or `repr(transparent)`) is used. Every +/// field must implement [`AsBytes`]. +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{AsBytes, U32}; +/// #[repr(C)] +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(AsBytes)': +/// unsafe impl AsBytes for Foo +/// where Bar: AsBytes { +/// // The default implementation of 'as_bytes()' works. +/// } +/// ``` +pub use domain_macros::AsBytes; + +//----------- TruncationError ------------------------------------------------ + +/// A DNS message did not fit in a buffer. +#[derive(Clone, Debug, PartialEq, Hash)] +pub struct TruncationError; + +//--- Formatting + +impl fmt::Display for TruncationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("A buffer was too small to fit a DNS message") + } +} diff --git a/src/new_base/wire/ints.rs b/src/new_base/wire/ints.rs new file mode 100644 index 000000000..3d11f45e4 --- /dev/null +++ b/src/new_base/wire/ints.rs @@ -0,0 +1,282 @@ +//! Integer primitives for the DNS wire format. + +use core::{ + cmp::Ordering, + fmt, + ops::{ + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, + BitXorAssign, Not, Sub, SubAssign, + }, +}; + +use domain_macros::*; + +use super::{ + ParseBytes, ParseBytesByRef, ParseError, SplitBytes, SplitBytesByRef, +}; + +//----------- define_int ----------------------------------------------------- + +/// Define a network endianness integer primitive. +macro_rules! define_int { + { $( + $(#[$docs:meta])* + $name:ident([u8; $size:literal]) = $base:ident; + )* } => { $( + $(#[$docs])* + #[derive( + Copy, + Clone, + Default, + PartialEq, + Eq, + Hash, + AsBytes, + BuildBytes, + ParseBytesByRef, + SplitBytesByRef, + )] + #[repr(transparent)] + pub struct $name([u8; $size]); + + //--- Conversion to and from integer primitive types + + impl $name { + /// Convert an integer to network endianness. + pub const fn new(value: $base) -> Self { + Self(value.to_be_bytes()) + } + + /// Convert an integer from network endianness. + pub const fn get(self) -> $base { + <$base>::from_be_bytes(self.0) + } + } + + impl From<$base> for $name { + fn from(value: $base) -> Self { + Self::new(value) + } + } + + impl From<$name> for $base { + fn from(value: $name) -> Self { + value.get() + } + } + + //--- Parsing from bytes + + impl<'b> ParseBytes<'b> for $name { + fn parse_bytes(bytes: &'b [u8]) -> Result { + Self::parse_bytes_by_ref(bytes).copied() + } + } + + impl<'b> SplitBytes<'b> for $name { + fn split_bytes( + bytes: &'b [u8], + ) -> Result<(Self, &'b [u8]), ParseError> { + Self::split_bytes_by_ref(bytes) + .map(|(&this, rest)| (this, rest)) + } + } + + //--- Comparison + + impl PartialEq<$base> for $name { + fn eq(&self, other: &$base) -> bool { + self.get() == *other + } + } + + impl PartialOrd<$base> for $name { + fn partial_cmp(&self, other: &$base) -> Option { + self.get().partial_cmp(other) + } + } + + impl PartialOrd for $name { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for $name { + fn cmp(&self, other: &Self) -> Ordering { + self.get().cmp(&other.get()) + } + } + + //--- Formatting + + impl fmt::Debug for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple(stringify!($name)).field(&self.get()).finish() + } + } + + //--- Arithmetic + + impl Add for $name { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self::new(self.get() + rhs.get()) + } + } + + impl AddAssign for $name { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } + } + + impl Add<$base> for $name { + type Output = Self; + + fn add(self, rhs: $base) -> Self::Output { + Self::new(self.get() + rhs) + } + } + + impl AddAssign<$base> for $name { + fn add_assign(&mut self, rhs: $base) { + *self = *self + rhs; + } + } + + impl Sub for $name { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self::new(self.get() - rhs.get()) + } + } + + impl SubAssign for $name { + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } + } + + impl Sub<$base> for $name { + type Output = Self; + + fn sub(self, rhs: $base) -> Self::Output { + Self::new(self.get() - rhs) + } + } + + impl SubAssign<$base> for $name { + fn sub_assign(&mut self, rhs: $base) { + *self = *self - rhs; + } + } + + impl Not for $name { + type Output = Self; + + fn not(self) -> Self::Output { + Self::new(!self.get()) + } + } + + //--- Bitwise operations + + impl BitAnd for $name { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + Self::new(self.get() & rhs.get()) + } + } + + impl BitAndAssign for $name { + fn bitand_assign(&mut self, rhs: Self) { + *self = *self & rhs; + } + } + + impl BitAnd<$base> for $name { + type Output = Self; + + fn bitand(self, rhs: $base) -> Self::Output { + Self::new(self.get() & rhs) + } + } + + impl BitAndAssign<$base> for $name { + fn bitand_assign(&mut self, rhs: $base) { + *self = *self & rhs; + } + } + + impl BitOr for $name { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + Self::new(self.get() | rhs.get()) + } + } + + impl BitOrAssign for $name { + fn bitor_assign(&mut self, rhs: Self) { + *self = *self | rhs; + } + } + + impl BitOr<$base> for $name { + type Output = Self; + + fn bitor(self, rhs: $base) -> Self::Output { + Self::new(self.get() | rhs) + } + } + + impl BitOrAssign<$base> for $name { + fn bitor_assign(&mut self, rhs: $base) { + *self = *self | rhs; + } + } + + impl BitXor for $name { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self::Output { + Self::new(self.get() ^ rhs.get()) + } + } + + impl BitXorAssign for $name { + fn bitxor_assign(&mut self, rhs: Self) { + *self = *self ^ rhs; + } + } + + impl BitXor<$base> for $name { + type Output = Self; + + fn bitxor(self, rhs: $base) -> Self::Output { + Self::new(self.get() ^ rhs) + } + } + + impl BitXorAssign<$base> for $name { + fn bitxor_assign(&mut self, rhs: $base) { + *self = *self ^ rhs; + } + } + )* }; +} + +define_int! { + /// An unsigned 16-bit integer in network endianness. + U16([u8; 2]) = u16; + + /// An unsigned 32-bit integer in network endianness. + U32([u8; 4]) = u32; + + /// An unsigned 64-bit integer in network endianness. + U64([u8; 8]) = u64; +} diff --git a/src/new_base/wire/mod.rs b/src/new_base/wire/mod.rs new file mode 100644 index 000000000..4d5be5c25 --- /dev/null +++ b/src/new_base/wire/mod.rs @@ -0,0 +1,81 @@ +//! The basic wire format of network protocols. +//! +//! This is a low-level module providing simple and efficient mechanisms to +//! parse data from and build data into byte sequences. It takes inspiration +//! from the [zerocopy] crate, but 1) is significantly simpler, 2) has simple +//! requirements for its `derive` macros, and 3) supports parsing out-of-place +//! (i.e. non-zero-copy). +//! +//! [zerocopy]: https://github.com/google/zerocopy +//! +//! # Design +//! +//! When a type is defined to represent a component of a network packet, its +//! internal structure should match the structure of its wire format. Here's +//! an example of a question in a DNS record: +//! +//! ``` +//! # use domain::new_base::{QType, QClass, wire::*}; +//! #[derive(BuildBytes, ParseBytes, SplitBytes)] +//! pub struct Question { +//! /// The domain name being requested. +//! pub qname: N, +//! +//! /// The type of the requested records. +//! pub qtype: QType, +//! +//! /// The class of the requested records. +//! pub qclass: QClass, +//! } +//! ``` +//! +//! This exactly matches the structure of a question on the wire -- the QNAME, +//! the QTYPE, and the QCLASS. This allows the definition of the type to also +//! specify the wire format concisely. +//! +//! Now, this type can be read from and written to bytes very easily: +//! +//! ``` +//! # use domain::new_base::{Question, name::RevNameBuf, wire::*}; +//! // { qname: "org.", qtype: A, qclass: IN } +//! let bytes = [3, 111, 114, 103, 0, 0, 1, 0, 1]; +//! let question = Question::::parse_bytes(&bytes).unwrap(); +//! let mut duplicate = [0u8; 9]; +//! let rest = question.build_bytes(&mut duplicate).unwrap(); +//! assert_eq!(*rest, []); +//! assert_eq!(bytes, duplicate); +//! ``` +//! +//! There are three important traits to consider: +//! +//! - [`ParseBytes`]: For interpreting an entire byte string as an instance of +//! the target type. +//! +//! - [`SplitBytes`]: For interpreting _the start_ of a byte string as an +//! instance of the target type. +//! +//! - [`BuildBytes`]: For serializing an object and writing it to the _start_ +//! of a byte string. +//! +//! These operate by value, and copy (some) data from the input. However, +//! there are also zero-copy versions of these traits, which are more +//! efficient (but not always applicable): +//! +//! - [`ParseBytesByRef`]: Like [`ParseBytes`], but transmutes the byte string +//! into an instance of the target type in place. +//! +//! - [`SplitBytesByRef`]: Like [`SplitBytes`], but transmutes the byte string +//! into an instance of the target type in place. +//! +//! - [`AsBytes`]: Allows interpreting an object as a byte string in place. + +mod build; +pub use build::{AsBytes, BuildBytes, TruncationError}; + +mod parse; +pub use parse::{ + ParseBytes, ParseBytesByRef, ParseError, SplitBytes, SplitBytesByRef, +}; + +mod ints; +pub use ints::{U16, U32, U64}; diff --git a/src/new_base/wire/parse.rs b/src/new_base/wire/parse.rs new file mode 100644 index 000000000..3ee5d44a1 --- /dev/null +++ b/src/new_base/wire/parse.rs @@ -0,0 +1,510 @@ +//! Parsing bytes in the basic network format. + +use core::fmt; + +//----------- ParseBytes ----------------------------------------------------- + +/// Parsing from a byte string. +pub trait ParseBytes<'a>: Sized { + /// Parse a value of [`Self`] from the given byte string. + /// + /// If parsing is successful, the parsed value is returned. Otherwise, a + /// [`ParseError`] is returned. + fn parse_bytes(bytes: &'a [u8]) -> Result; +} + +impl<'a> ParseBytes<'a> for u8 { + fn parse_bytes(bytes: &'a [u8]) -> Result { + let [result] = bytes else { + return Err(ParseError); + }; + + Ok(*result) + } +} + +impl<'a, T: ?Sized + ParseBytesByRef> ParseBytes<'a> for &'a T { + fn parse_bytes(bytes: &'a [u8]) -> Result { + T::parse_bytes_by_ref(bytes).map_err(|_| ParseError) + } +} + +/// Deriving [`ParseBytes`] automatically. +/// +/// [`ParseBytes`] can be derived on `struct`s (not `enum`s or `union`s). All +/// fields except the last must implement [`SplitBytes`], while the last field +/// only needs to implement [`ParseBytes`]. +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{ParseBytes, SplitBytes, U32, ParseError}; +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(ParseBytes)': +/// impl<'bytes, T> ParseBytes<'bytes> for Foo +/// where +/// U32: SplitBytes<'bytes>, +/// Bar: ParseBytes<'bytes>, +/// { +/// fn parse_bytes( +/// bytes: &'bytes [u8], +/// ) -> Result { +/// let (field_a, bytes) = U32::split_bytes(bytes)?; +/// let field_b = >::parse_bytes(bytes)?; +/// Ok(Self { a: field_a, b: field_b }) +/// } +/// } +/// ``` +pub use domain_macros::ParseBytes; + +//----------- SplitBytes ----------------------------------------------------- + +/// Parsing from the start of a byte string. +pub trait SplitBytes<'a>: Sized + ParseBytes<'a> { + /// Parse a value of [`Self`] from the start of the byte string. + /// + /// If parsing is successful, the parsed value and the rest of the string + /// are returned. Otherwise, a [`ParseError`] is returned. + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError>; +} + +impl<'a, T: ?Sized + SplitBytesByRef> SplitBytes<'a> for &'a T { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + T::split_bytes_by_ref(bytes).map_err(|_| ParseError) + } +} + +impl<'a> SplitBytes<'a> for u8 { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + bytes.split_first().map(|(&f, r)| (f, r)).ok_or(ParseError) + } +} + +/// Deriving [`SplitBytes`] automatically. +/// +/// [`SplitBytes`] can be derived on `struct`s (not `enum`s or `union`s). All +/// fields except the last must implement [`SplitBytes`], while the last field +/// only needs to implement [`SplitBytes`]. +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{ParseBytes, SplitBytes, U32, ParseError}; +/// #[derive(ParseBytes)] +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(SplitBytes)': +/// impl<'bytes, T> SplitBytes<'bytes> for Foo +/// where +/// U32: SplitBytes<'bytes>, +/// Bar: SplitBytes<'bytes>, +/// { +/// fn split_bytes( +/// bytes: &'bytes [u8], +/// ) -> Result<(Self, &'bytes [u8]), ParseError> { +/// let (field_a, bytes) = U32::split_bytes(bytes)?; +/// let (field_b, bytes) = >::split_bytes(bytes)?; +/// Ok((Self { a: field_a, b: field_b }, bytes)) +/// } +/// } +/// ``` +pub use domain_macros::SplitBytes; + +//----------- ParseBytesByRef ------------------------------------------------ + +/// Zero-copy parsing from a byte string. +/// +/// # Safety +/// +/// Every implementation of [`ParseBytesByRef`] must satisfy the invariants +/// documented on [`parse_bytes_by_ref()`] and [`ptr_with_address()`]. An +/// incorrect implementation is considered to cause undefined behaviour. +/// +/// [`parse_bytes_by_ref()`]: Self::parse_bytes_by_ref() +/// [`ptr_with_address()`]: Self::ptr_with_address() +/// +/// Implementing types must also have no alignment (i.e. a valid instance of +/// [`Self`] can occur at any address). This eliminates the possibility of +/// padding bytes, even when [`Self`] is part of a larger aggregate type. +pub unsafe trait ParseBytesByRef { + /// Interpret a byte string as an instance of [`Self`]. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The whole byte string will be used. If the input is not a + /// valid instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let this: &T = T::parse_bytes_by_ref(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this)`. + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError>; + + /// Interpret a byte string as an instance of [`Self`], mutably. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The whole byte string will be used. If the input is not a + /// valid instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let this: &mut T = T::parse_bytes_by_mut(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this)`. + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError>; + + /// Change the address of a pointer to [`Self`]. + /// + /// When [`Self`] is used as the last field in a type that also implements + /// [`ParseBytesByRef`], it may be dynamically sized, and so a pointer (or + /// reference) to it may include additional metadata. This metadata is + /// included verbatim in any reference/pointer to the containing type. + /// + /// When the containing type implements [`ParseBytesByRef`], it needs to + /// construct a reference/pointer to itself, which includes this metadata. + /// Rust does not currently offer a general way to extract this metadata + /// or pair it with another address, so this function is necessary. The + /// caller can construct a reference to [`Self`], then change its address + /// to point to the containing type, then cast that pointer to the right + /// type. + /// + /// # Implementing + /// + /// Most users will derive [`ParseBytesByRef`] and so don't need to worry + /// about this. For manual implementations: + /// + /// In the future, an adequate default implementation for this function + /// may be provided. Until then, it should be implemented using one of + /// the following expressions: + /// + /// ```ignore + /// fn ptr_with_address( + /// &self, + /// addr: *const (), + /// ) -> *const Self { + /// // If 'Self' is Sized: + /// addr.cast::() + /// + /// // If 'Self' is an aggregate whose last field is 'last': + /// self.last.ptr_with_address(addr) as *const Self + /// } + /// ``` + /// + /// # Invariants + /// + /// For the statement `let result = Self::ptr_with_address(ptr, addr);`: + /// + /// - `result as usize == addr as usize`. + /// - `core::ptr::metadata(result) == core::ptr::metadata(ptr)`. + fn ptr_with_address(&self, addr: *const ()) -> *const Self; +} + +unsafe impl ParseBytesByRef for u8 { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + if let [result] = bytes { + Ok(result) + } else { + Err(ParseError) + } + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + if let [result] = bytes { + Ok(result) + } else { + Err(ParseError) + } + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } +} + +unsafe impl ParseBytesByRef for [u8] { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + Ok(bytes) + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + Ok(bytes) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + core::ptr::slice_from_raw_parts(addr.cast(), self.len()) + } +} + +unsafe impl ParseBytesByRef for str { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + core::str::from_utf8(bytes).map_err(|_| ParseError) + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + core::str::from_utf8_mut(bytes).map_err(|_| ParseError) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + // NOTE: The Rust Reference indicates that 'str' has the same layout + // as '[u8]' [1]. This is also the most natural layout for it. Since + // there's no way to construct a '*const str' from raw parts, we will + // just construct a raw slice and transmute it. + // + // [1]: https://doc.rust-lang.org/reference/type-layout.html#str-layout + + self.as_bytes().ptr_with_address(addr) as *const Self + } +} + +unsafe impl ParseBytesByRef for [T; N] { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + let (this, rest) = Self::split_bytes_by_ref(bytes)?; + if rest.is_empty() { + Ok(this) + } else { + Err(ParseError) + } + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + let (this, rest) = Self::split_bytes_by_mut(bytes)?; + if rest.is_empty() { + Ok(this) + } else { + Err(ParseError) + } + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + addr.cast() + } +} + +/// Deriving [`ParseBytesByRef`] automatically. +/// +/// [`ParseBytesByRef`] can be derived on `struct`s (not `enum`s or `union`s), +/// where a fixed memory layout (`repr(C)` or `repr(transparent)`) is used. +/// All fields except the last must implement [`SplitBytesByRef`], while the +/// last field only needs to implement [`ParseBytesByRef`]. +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{ParseBytesByRef, SplitBytesByRef, U32, ParseError}; +/// #[repr(C)] +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(ParseBytesByRef)': +/// unsafe impl ParseBytesByRef for Foo +/// where Bar: ParseBytesByRef { +/// fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { +/// let addr = bytes.as_ptr(); +/// let (_, bytes) = U32::split_bytes_by_ref(bytes)?; +/// let last = >::parse_bytes_by_ref(bytes)?; +/// let this = last.ptr_with_address(addr as *const ()); +/// Ok(unsafe { &*(this as *const Self) }) +/// } +/// +/// fn parse_bytes_by_mut( +/// bytes: &mut [u8], +/// ) -> Result<&mut Self, ParseError> { +/// let addr = bytes.as_ptr(); +/// let (_, bytes) = U32::split_bytes_by_ref(bytes)?; +/// let last = >::parse_bytes_by_ref(bytes)?; +/// let this = last.ptr_with_address(addr as *const ()); +/// Ok(unsafe { &mut *(this as *const Self as *mut Self) }) +/// } +/// +/// fn ptr_with_address(&self, addr: *const ()) -> *const Self { +/// self.b.ptr_with_address(addr) as *const Self +/// } +/// } +/// ``` +pub use domain_macros::ParseBytesByRef; + +//----------- SplitBytesByRef ------------------------------------------------ + +/// Zero-copy parsing from the start of a byte string. +/// +/// This is an extension of [`ParseBytesByRef`] for types which can determine +/// their own length when parsing. It is usually implemented by [`Sized`] +/// types (where the length is just the size of the type), although it can be +/// sometimes implemented by unsized types. +/// +/// # Safety +/// +/// Every implementation of [`SplitBytesByRef`] must satisfy the invariants +/// documented on [`split_bytes_by_ref()`]. An incorrect implementation is +/// considered to cause undefined behaviour. +/// +/// [`split_bytes_by_ref()`]: Self::split_bytes_by_ref() +/// +/// Note that [`ParseBytesByRef`], required by this trait, also has several +/// invariants that need to be considered with care. +pub unsafe trait SplitBytesByRef: ParseBytesByRef { + /// Interpret a byte string as an instance of [`Self`], mutably. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The length of [`Self`] will be determined, possibly based + /// on the contents (but not the length!) of the input, and the remaining + /// bytes will be returned. If the input does not begin with a valid + /// instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let (this, rest) = T::split_bytes_by_ref(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this) + rest.len()`. + /// - `bytes.as_ptr().offset(size_of_val(this)) == rest.as_ptr()`. + fn split_bytes_by_ref(bytes: &[u8]) + -> Result<(&Self, &[u8]), ParseError>; + + /// Interpret a byte string as an instance of [`Self`]. + /// + /// The byte string will be validated and re-interpreted as a reference to + /// [`Self`]. The length of [`Self`] will be determined, possibly based + /// on the contents (but not the length!) of the input, and the remaining + /// bytes will be returned. If the input does not begin with a valid + /// instance of [`Self`], a [`ParseError`] is returned. + /// + /// ## Invariants + /// + /// For the statement `let (this, rest) = T::split_bytes_by_mut(bytes)?;`, + /// + /// - `bytes.as_ptr() == this as *const T as *const u8`. + /// - `bytes.len() == core::mem::size_of_val(this) + rest.len()`. + /// - `bytes.as_ptr().offset(size_of_val(this)) == rest.as_ptr()`. + fn split_bytes_by_mut( + bytes: &mut [u8], + ) -> Result<(&mut Self, &mut [u8]), ParseError>; +} + +unsafe impl SplitBytesByRef for u8 { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + bytes.split_first().ok_or(ParseError) + } + + fn split_bytes_by_mut( + bytes: &mut [u8], + ) -> Result<(&mut Self, &mut [u8]), ParseError> { + bytes.split_first_mut().ok_or(ParseError) + } +} + +unsafe impl SplitBytesByRef for [T; N] { + fn split_bytes_by_ref( + mut bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + let start = bytes.as_ptr(); + for _ in 0..N { + (_, bytes) = T::split_bytes_by_ref(bytes)?; + } + + // SAFETY: + // - 'T::split_bytes_by_ref()' was called 'N' times on successive + // positions, thus the original 'bytes' starts with 'N' instances + // of 'T' (even if 'T' is a ZST and so all instances overlap). + // - 'N' consecutive 'T's have the same layout as '[T; N]'. + // - Thus it is safe to cast 'start' to '[T; N]'. + // - The referenced data has the same lifetime as the output. + Ok((unsafe { &*start.cast::<[T; N]>() }, bytes)) + } + + fn split_bytes_by_mut( + mut bytes: &mut [u8], + ) -> Result<(&mut Self, &mut [u8]), ParseError> { + let start = bytes.as_mut_ptr(); + for _ in 0..N { + (_, bytes) = T::split_bytes_by_mut(bytes)?; + } + + // SAFETY: + // - 'T::split_bytes_by_ref()' was called 'N' times on successive + // positions, thus the original 'bytes' starts with 'N' instances + // of 'T' (even if 'T' is a ZST and so all instances overlap). + // - 'N' consecutive 'T's have the same layout as '[T; N]'. + // - Thus it is safe to cast 'start' to '[T; N]'. + // - The referenced data has the same lifetime as the output. + Ok((unsafe { &mut *start.cast::<[T; N]>() }, bytes)) + } +} + +/// Deriving [`SplitBytesByRef`] automatically. +/// +/// [`SplitBytesByRef`] can be derived on `struct`s (not `enum`s or `union`s), +/// where a fixed memory layout (`repr(C)` or `repr(transparent)`) is used. +/// All fields must implement [`SplitBytesByRef`]. +/// +/// Here's a simple example: +/// +/// ``` +/// # use domain::new_base::wire::{ParseBytesByRef, SplitBytesByRef, U32, ParseError}; +/// #[derive(ParseBytesByRef)] +/// #[repr(C)] +/// struct Foo { +/// a: U32, +/// b: Bar, +/// } +/// +/// # struct Bar { data: T } +/// +/// // The generated impl with 'derive(SplitBytesByRef)': +/// unsafe impl SplitBytesByRef for Foo +/// where Bar: SplitBytesByRef { +/// fn split_bytes_by_ref( +/// bytes: &[u8], +/// ) -> Result<(&Self, &[u8]), ParseError> { +/// let addr = bytes.as_ptr(); +/// let (_, bytes) = U32::split_bytes_by_ref(bytes)?; +/// let (last, bytes) = >::split_bytes_by_ref(bytes)?; +/// let this = last.ptr_with_address(addr as *const ()); +/// Ok((unsafe { &*(this as *const Self) }, bytes)) +/// } +/// +/// fn split_bytes_by_mut( +/// bytes: &mut [u8], +/// ) -> Result<(&mut Self, &mut [u8]), ParseError> { +/// let addr = bytes.as_ptr(); +/// let (_, bytes) = U32::split_bytes_by_mut(bytes)?; +/// let (last, bytes) = >::split_bytes_by_mut(bytes)?; +/// let this = last.ptr_with_address(addr as *const ()); +/// Ok((unsafe { &mut *(this as *const Self as *mut Self) }, bytes)) +/// } +/// } +/// ``` +pub use domain_macros::SplitBytesByRef; + +//----------- ParseError ----------------------------------------------------- + +/// A DNS message parsing error. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ParseError; + +//--- Formatting + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("DNS data could not be parsed from the wire format") + } +} diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs index 1a815e615..36d96a4f4 100644 --- a/src/new_edns/cookie.rs +++ b/src/new_edns/cookie.rs @@ -18,7 +18,7 @@ use domain_macros::*; use crate::new_base::Serial; #[cfg(all(feature = "std", feature = "siphasher"))] -use crate::new_base::build::{AsBytes, TruncationError}; +use crate::new_base::wire::{AsBytes, TruncationError}; //----------- CookieRequest -------------------------------------------------- @@ -71,7 +71,7 @@ impl CookieRequest { use siphasher::sip::SipHasher24; - use crate::new_base::build::BuildBytes; + use crate::new_base::wire::BuildBytes; // Build and hash the cookie simultaneously. let mut hasher = SipHasher24::new_with_key(secret); diff --git a/src/new_edns/ext_err.rs b/src/new_edns/ext_err.rs index 6858aa713..7613afd8e 100644 --- a/src/new_edns/ext_err.rs +++ b/src/new_edns/ext_err.rs @@ -6,12 +6,12 @@ use core::fmt; use domain_macros::*; -use zerocopy::network_endian::U16; +use crate::new_base::wire::U16; //----------- ExtError ------------------------------------------------------- /// An extended DNS error. -#[derive(AsBytes, ParseBytesByRef)] +#[derive(AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C)] pub struct ExtError { /// The error code. diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 3b3ab2a80..781224360 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -4,16 +4,14 @@ use core::{fmt, ops::Range}; -use zerocopy::{network_endian::U16, IntoBytes}; - use domain_macros::*; use crate::{ new_base::{ - build::{AsBytes, BuildBytes, TruncationError}, - parse::{ - ParseBytes, ParseBytesByRef, ParseError, ParseFromMessage, - SplitBytes, SplitFromMessage, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{ + AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, + SplitBytes, TruncationError, U16, }, Message, }, diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index d9e8829ac..14c7cdc9f 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -10,15 +10,14 @@ use core::str::FromStr; #[cfg(feature = "std")] use std::net::Ipv4Addr; -use zerocopy::network_endian::{U16, U32}; - use domain_macros::*; use crate::new_base::{ - build::{self, AsBytes, BuildIntoMessage, TruncationError}, - parse::{ - ParseBytes, ParseError, ParseFromMessage, SplitBytes, - SplitFromMessage, + build::{self, BuildIntoMessage}, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{ + AsBytes, ParseBytes, ParseError, SplitBytes, TruncationError, U16, + U32, }, CharStr, Message, Serial, }; @@ -386,8 +385,6 @@ impl<'a> ParseFromMessage<'a> for HInfo<'a> { message: &'a Message, range: Range, ) -> Result { - use zerocopy::IntoBytes; - message .as_bytes() .get(range) @@ -501,8 +498,6 @@ impl<'a> ParseFromMessage<'a> for &'a Txt { message: &'a Message, range: Range, ) -> Result { - use zerocopy::IntoBytes; - message .as_bytes() .get(range) diff --git a/src/new_rdata/ipv6.rs b/src/new_rdata/ipv6.rs index fb3f9d30e..788a1ca97 100644 --- a/src/new_rdata/ipv6.rs +++ b/src/new_rdata/ipv6.rs @@ -10,8 +10,9 @@ use std::net::Ipv6Addr; use domain_macros::*; -use crate::new_base::build::{ - self, AsBytes, BuildIntoMessage, TruncationError, +use crate::new_base::{ + build::{self, BuildIntoMessage}, + wire::{AsBytes, TruncationError}, }; //----------- Aaaa ----------------------------------------------------------- diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index ebdcc7743..23be1e18f 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -5,11 +5,9 @@ use core::ops::Range; use domain_macros::*; use crate::new_base::{ - build::{BuildBytes, BuildIntoMessage, Builder, TruncationError}, - parse::{ - ParseBytes, ParseError, ParseFromMessage, SplitBytes, - SplitFromMessage, - }, + build::{self, BuildIntoMessage}, + parse::{ParseFromMessage, SplitFromMessage}, + wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, ParseRecordData, RType, }; @@ -131,7 +129,7 @@ where impl BuildIntoMessage for RecordData<'_, N> { fn build_into_message( &self, - builder: Builder<'_>, + builder: build::Builder<'_>, ) -> Result<(), TruncationError> { match self { Self::A(r) => r.build_into_message(builder), From af13cf14cb31684770205eb2f83ed4b79edbddf8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 18:34:51 +0100 Subject: [PATCH 072/167] [new_base] Correct docs for build traits --- src/new_base/build/mod.rs | 5 ++--- src/new_base/wire/build.rs | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 2faca3c16..35752d2e2 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -11,9 +11,8 @@ pub use super::wire::TruncationError; pub trait BuildIntoMessage { // Append this value to the DNS message. /// - /// If the byte string is long enough to fit the message, it is appended - /// using the given message builder and committed. Otherwise, a - /// [`TruncationError`] is returned. + /// If the builder has enough capacity to fit the message, it is appended + /// and committed. Otherwise, a [`TruncationError`] is returned. fn build_into_message( &self, builder: Builder<'_>, diff --git a/src/new_base/wire/build.rs b/src/new_base/wire/build.rs index 1b67a4d40..88b6a44b3 100644 --- a/src/new_base/wire/build.rs +++ b/src/new_base/wire/build.rs @@ -12,7 +12,7 @@ pub trait BuildBytes { /// buffer. If the buffer is large enough, the whole object is written /// and the remaining (unmodified) part of the buffer is returned. /// - /// if the buffer is too small, a [`TruncationError`] is returned (and + /// If the buffer is too small, a [`TruncationError`] is returned (and /// parts of the buffer may be modified). fn build_bytes<'b>( &self, From 8daf6b502b7c172fd9cccce55a29026f12f1e566 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 18:42:25 +0100 Subject: [PATCH 073/167] [macros] Avoid glob imports where possible --- macros/src/data.rs | 2 +- macros/src/impls.rs | 58 ++++++++++++++++++++++++++------------------- macros/src/repr.rs | 12 +++++++--- 3 files changed, 43 insertions(+), 29 deletions(-) diff --git a/macros/src/data.rs b/macros/src/data.rs index 6a0788b3e..ee0c52baf 100644 --- a/macros/src/data.rs +++ b/macros/src/data.rs @@ -4,7 +4,7 @@ use std::ops::Deref; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; -use syn::{spanned::Spanned, *}; +use syn::{spanned::Spanned, Field, Fields, Ident, Index, Member, Token}; //----------- Struct --------------------------------------------------------- diff --git a/macros/src/impls.rs b/macros/src/impls.rs index 2d9724f0e..4c0971998 100644 --- a/macros/src/impls.rs +++ b/macros/src/impls.rs @@ -2,7 +2,11 @@ use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, ToTokens}; -use syn::{punctuated::Punctuated, visit::Visit, *}; +use syn::{ + punctuated::Punctuated, visit::Visit, ConstParam, GenericArgument, + GenericParam, Ident, Lifetime, LifetimeParam, Token, TypeParam, + TypeParamBound, WhereClause, WherePredicate, +}; //----------- ImplSkeleton --------------------------------------------------- @@ -21,24 +25,24 @@ pub struct ImplSkeleton { pub unsafety: Option, /// The trait being implemented. - pub bound: Option, + pub bound: Option, /// The type being implemented on. - pub subject: Path, + pub subject: syn::Path, /// The where clause of the `impl` block. pub where_clause: WhereClause, /// The contents of the `impl`. - pub contents: Block, + pub contents: syn::Block, /// A `const` block for asserting requirements. - pub requirements: Block, + pub requirements: syn::Block, } impl ImplSkeleton { /// Construct an [`ImplSkeleton`] for a [`DeriveInput`]. - pub fn new(input: &DeriveInput, unsafety: bool) -> Self { + pub fn new(input: &syn::DeriveInput, unsafety: bool) -> Self { let mut lifetimes = Vec::new(); let mut types = Vec::new(); let mut consts = Vec::new(); @@ -55,13 +59,13 @@ impl ImplSkeleton { GenericParam::Type(value) => { types.push(value.clone()); let id = value.ident.clone(); - let id = TypePath { + let id = syn::TypePath { qself: None, - path: Path { + path: syn::Path { leading_colon: None, - segments: [PathSegment { + segments: [syn::PathSegment { ident: id, - arguments: PathArguments::None, + arguments: syn::PathArguments::None, }] .into_iter() .collect(), @@ -73,13 +77,13 @@ impl ImplSkeleton { GenericParam::Const(value) => { consts.push(value.clone()); let id = value.ident.clone(); - let id = TypePath { + let id = syn::TypePath { qself: None, - path: Path { + path: syn::Path { leading_colon: None, - segments: [PathSegment { + segments: [syn::PathSegment { ident: id, - arguments: PathArguments::None, + arguments: syn::PathArguments::None, }] .into_iter() .collect(), @@ -92,12 +96,12 @@ impl ImplSkeleton { let unsafety = unsafety.then_some(::default()); - let subject = Path { + let subject = syn::Path { leading_colon: None, - segments: [PathSegment { + segments: [syn::PathSegment { ident: input.ident.clone(), - arguments: PathArguments::AngleBracketed( - AngleBracketedGenericArguments { + arguments: syn::PathArguments::AngleBracketed( + syn::AngleBracketedGenericArguments { colon2_token: None, lt_token: Default::default(), args: subject_args, @@ -115,12 +119,12 @@ impl ImplSkeleton { predicates: Punctuated::new(), }); - let contents = Block { + let contents = syn::Block { brace_token: Default::default(), stmts: Vec::new(), }; - let requirements = Block { + let requirements = syn::Block { brace_token: Default::default(), stmts: Vec::new(), }; @@ -142,7 +146,11 @@ impl ImplSkeleton { /// /// If the type is concrete, a verifying statement is added for it. /// Otherwise, it is added to the where clause. - pub fn require_bound(&mut self, target: Type, bound: TypeParamBound) { + pub fn require_bound( + &mut self, + target: syn::Type, + bound: TypeParamBound, + ) { let mut visitor = ConcretenessVisitor { skeleton: self, is_concrete: true, @@ -154,7 +162,7 @@ impl ImplSkeleton { if visitor.is_concrete { // Add a concrete requirement for this bound. - self.requirements.stmts.push(parse_quote! { + self.requirements.stmts.push(syn::parse_quote! { const _: fn() = || { fn assert_impl() {} assert_impl::<#target>(); @@ -164,7 +172,7 @@ impl ImplSkeleton { // Add this bound to the `where` clause. let mut bounds = Punctuated::new(); bounds.push(bound); - let pred = WherePredicate::Type(PredicateType { + let pred = WherePredicate::Type(syn::PredicateType { lifetimes: None, bounded_ty: target, colon_token: Default::default(), @@ -196,9 +204,9 @@ impl ImplSkeleton { let lifetime = self.new_lifetime(prefix); let mut bounds = bounds.into_iter().peekable(); let param = if bounds.peek().is_some() { - parse_quote! { #lifetime: #(#bounds)+* } + syn::parse_quote! { #lifetime: #(#bounds)+* } } else { - parse_quote! { #lifetime } + syn::parse_quote! { #lifetime } }; (lifetime, param) } diff --git a/macros/src/repr.rs b/macros/src/repr.rs index 80c900eb6..b699b571b 100644 --- a/macros/src/repr.rs +++ b/macros/src/repr.rs @@ -1,7 +1,10 @@ //! Determining the memory layout of a type. use proc_macro2::Span; -use syn::{punctuated::Punctuated, spanned::Spanned, *}; +use syn::{ + punctuated::Punctuated, spanned::Spanned, Attribute, Error, LitInt, Meta, + Token, +}; //----------- Repr ----------------------------------------------------------- @@ -19,7 +22,10 @@ impl Repr { /// Determine the representation for a type from its attributes. /// /// This will fail if a stable representation cannot be found. - pub fn determine(attrs: &[Attribute], bound: &str) -> Result { + pub fn determine( + attrs: &[Attribute], + bound: &str, + ) -> Result { let mut repr = None; for attr in attrs { if !attr.path().is_ident("repr") { @@ -57,7 +63,7 @@ impl Repr { || meta.path.is_ident("aligned") => { let span = meta.span(); - let lit: LitInt = parse2(meta.tokens)?; + let lit: LitInt = syn::parse2(meta.tokens)?; let n: usize = lit.base10_parse()?; if n != 1 { return Err(Error::new(span, From 8bf87bee84595a1e28ee3db1d828a51e72e192f2 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 6 Jan 2025 18:46:27 +0100 Subject: [PATCH 074/167] [macros/lib.rs] Remove the last 'syn' glob import --- macros/src/lib.rs | 122 +++++++++++++++++++++++----------------------- 1 file changed, 62 insertions(+), 60 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 3fa1bc18b..a23e6902a 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -5,7 +5,7 @@ use proc_macro as pm; use proc_macro2::TokenStream; use quote::{format_ident, ToTokens}; -use syn::*; +use syn::{Error, Ident, Result}; mod impls; use impls::ImplSkeleton; @@ -20,16 +20,16 @@ use repr::Repr; #[proc_macro_derive(SplitBytes)] pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'SplitBytes' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'SplitBytes' can only be 'derive'd for 'struct's", @@ -47,7 +47,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { ); skeleton.lifetimes.push(param); skeleton.bound = Some( - parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), + syn::parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); // Inspect the 'struct' fields. @@ -58,14 +58,14 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), + syn::parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); } // Define 'parse_bytes()'. let init_vars = builder.init_vars(); let tys = data.fields().map(|f| &f.ty); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn split_bytes( bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -82,7 +82,7 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -92,16 +92,16 @@ pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream { #[proc_macro_derive(ParseBytes)] pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'ParseBytes' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'ParseBytes' can only be 'derive'd for 'struct's", @@ -119,7 +119,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { ); skeleton.lifetimes.push(param); skeleton.bound = Some( - parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), + syn::parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), ); // Inspect the 'struct' fields. @@ -130,19 +130,19 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), + syn::parse_quote!(::domain::new_base::wire::SplitBytes<#lifetime>), ); } if let Some(field) = data.unsized_field() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), + syn::parse_quote!(::domain::new_base::wire::ParseBytes<#lifetime>), ); } // Finish early if the 'struct' has no fields. if data.is_empty() { - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn parse_bytes( bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -165,7 +165,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { let tys = builder.sized_fields().map(|f| &f.ty); let unsized_ty = &builder.unsized_field().unwrap().ty; let unsized_init_var = builder.unsized_init_var().unwrap(); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn parse_bytes( bytes: & #lifetime [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -185,7 +185,7 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -195,16 +195,16 @@ pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream { #[proc_macro_derive(SplitBytesByRef)] pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'SplitBytesByRef' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'SplitBytesByRef' can only be 'derive'd for 'struct's", @@ -216,8 +216,9 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); - skeleton.bound = - Some(parse_quote!(::domain::new_base::wire::SplitBytesByRef)); + skeleton.bound = Some(syn::parse_quote!( + ::domain::new_base::wire::SplitBytesByRef + )); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -226,13 +227,13 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::SplitBytesByRef), + syn::parse_quote!(::domain::new_base::wire::SplitBytesByRef), ); } // Finish early if the 'struct' has no fields. if data.is_empty() { - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn split_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -255,7 +256,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Define 'split_bytes_by_ref()'. let tys = data.sized_fields().map(|f| &f.ty); let unsized_ty = &data.unsized_field().unwrap().ty; - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn split_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -289,7 +290,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Define 'split_bytes_by_mut()'. let tys = data.sized_fields().map(|f| &f.ty); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn split_bytes_by_mut( bytes: &mut [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -324,7 +325,7 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -334,16 +335,16 @@ pub fn derive_split_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { #[proc_macro_derive(ParseBytesByRef)] pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'ParseBytesByRef' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'ParseBytesByRef' can only be 'derive'd for 'struct's", @@ -355,8 +356,9 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); - skeleton.bound = - Some(parse_quote!(::domain::new_base::wire::ParseBytesByRef)); + skeleton.bound = Some(syn::parse_quote!( + ::domain::new_base::wire::ParseBytesByRef + )); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -365,19 +367,19 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { for field in data.sized_fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::SplitBytesByRef), + syn::parse_quote!(::domain::new_base::wire::SplitBytesByRef), ); } if let Some(field) = data.unsized_field() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::ParseBytesByRef), + syn::parse_quote!(::domain::new_base::wire::ParseBytesByRef), ); } // Finish early if the 'struct' has no fields. if data.is_empty() { - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn parse_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -395,7 +397,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { } }); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn ptr_with_address( &self, addr: *const (), @@ -410,7 +412,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Define 'parse_bytes_by_ref()'. let tys = data.sized_fields().map(|f| &f.ty); let unsized_ty = &data.unsized_field().unwrap().ty; - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn parse_bytes_by_ref( bytes: &[::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -444,7 +446,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Define 'parse_bytes_by_mut()'. let tys = data.sized_fields().map(|f| &f.ty); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn parse_bytes_by_mut( bytes: &mut [::domain::__core::primitive::u8], ) -> ::domain::__core::result::Result< @@ -478,7 +480,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { // Define 'ptr_with_address()'. let unsized_member = data.unsized_member(); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn ptr_with_address(&self, addr: *const ()) -> *const Self { <#unsized_ty as ::domain::new_base::wire::ParseBytesByRef> ::ptr_with_address(&self.#unsized_member, addr) @@ -489,7 +491,7 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -499,16 +501,16 @@ pub fn derive_parse_bytes_by_ref(input: pm::TokenStream) -> pm::TokenStream { #[proc_macro_derive(BuildBytes)] pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'BuildBytes' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'BuildBytes' can only be 'derive'd for 'struct's", @@ -519,7 +521,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, false); skeleton.bound = - Some(parse_quote!(::domain::new_base::wire::BuildBytes)); + Some(syn::parse_quote!(::domain::new_base::wire::BuildBytes)); // Inspect the 'struct' fields. let data = Struct::new_as_self(&data.fields); @@ -531,14 +533,14 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { for field in data.fields() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::BuildBytes), + syn::parse_quote!(::domain::new_base::wire::BuildBytes), ); } // Define 'build_bytes()'. let members = data.members(); let tys = data.fields().map(|f| &f.ty); - skeleton.contents.stmts.push(parse_quote! { + skeleton.contents.stmts.push(syn::parse_quote! { fn build_bytes<#lifetime>( &self, mut bytes: & #lifetime mut [::domain::__core::primitive::u8], @@ -555,7 +557,7 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -565,16 +567,16 @@ pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream { #[proc_macro_derive(AsBytes)] pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { - fn inner(input: DeriveInput) -> Result { + fn inner(input: syn::DeriveInput) -> Result { let data = match &input.data { - Data::Struct(data) => data, - Data::Enum(data) => { + syn::Data::Struct(data) => data, + syn::Data::Enum(data) => { return Err(Error::new_spanned( data.enum_token, "'AsBytes' can only be 'derive'd for 'struct's", )); } - Data::Union(data) => { + syn::Data::Union(data) => { return Err(Error::new_spanned( data.union_token, "'AsBytes' can only be 'derive'd for 'struct's", @@ -587,13 +589,13 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { // Construct an 'ImplSkeleton' so that we can add trait bounds. let mut skeleton = ImplSkeleton::new(&input, true); skeleton.bound = - Some(parse_quote!(::domain::new_base::wire::AsBytes)); + Some(syn::parse_quote!(::domain::new_base::wire::AsBytes)); // Establish bounds on the fields. for field in data.fields.iter() { skeleton.require_bound( field.ty.clone(), - parse_quote!(::domain::new_base::wire::AsBytes), + syn::parse_quote!(::domain::new_base::wire::AsBytes), ); } @@ -602,7 +604,7 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { Ok(skeleton.into_token_stream()) } - let input = syn::parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as syn::DeriveInput); inner(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -611,6 +613,6 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { //----------- Utility Functions ---------------------------------------------- /// Add a `field_` prefix to member names. -fn field_prefixed(member: Member) -> Ident { +fn field_prefixed(member: syn::Member) -> Ident { format_ident!("field_{}", member) } From 8afc305d60d0d2be93a83a6afe3ec33a85b150b9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 7 Jan 2025 12:09:17 +0100 Subject: [PATCH 075/167] [new_base/parse] Rework '{Parse,Split}FromMessage' Instead of passing the input selection as a range for parsing, the whole message is cut (using 'Message::slice_to()') and only the start is indicated. This ensures that we never cross the end of the range. It also implicitly dictates that compressed names are not allowed to reference future locations in messages. In addition, both the parsing traits now use offsets into the message _contents_ rather than the whole message. They can avoid 'as_bytes()' everywhere and have better guarantees of success. It also ensures the message header can never be selected for parsing. --- src/new_base/charstr.rs | 15 ++++++-------- src/new_base/message.rs | 17 +++++++++++++-- src/new_base/name/reversed.rs | 25 ++++++++++------------ src/new_base/parse/mod.rs | 36 ++++++++++++++++---------------- src/new_base/question.rs | 9 +++----- src/new_base/record.rs | 24 ++++++++------------- src/new_edns/mod.rs | 12 +++++------ src/new_rdata/basic.rs | 39 +++++++++++++++++------------------ src/new_rdata/mod.rs | 26 +++++++++++------------ 9 files changed, 99 insertions(+), 104 deletions(-) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 979e8be20..ce7e1fba7 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -1,14 +1,11 @@ //! DNS "character strings". -use core::{fmt, ops::Range}; +use core::fmt; use super::{ build::{self, BuildIntoMessage}, parse::{ParseFromMessage, SplitFromMessage}, - wire::{ - AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes, - TruncationError, - }, + wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, }; @@ -28,7 +25,7 @@ impl<'a> SplitFromMessage<'a> for &'a CharStr { message: &'a Message, start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = &message.as_bytes()[start..]; + let bytes = message.contents.get(start..).ok_or(ParseError)?; let (this, rest) = Self::split_bytes(bytes)?; Ok((this, bytes.len() - rest.len())) } @@ -37,11 +34,11 @@ impl<'a> SplitFromMessage<'a> for &'a CharStr { impl<'a> ParseFromMessage<'a> for &'a CharStr { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { message - .as_bytes() - .get(range) + .contents + .get(start..) .ok_or(ParseError) .and_then(Self::parse_bytes) } diff --git a/src/new_base/message.rs b/src/new_base/message.rs index 9c27d384f..ab1aa903c 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -2,9 +2,9 @@ use core::fmt; -use domain_macros::{AsBytes, *}; +use domain_macros::*; -use super::wire::U16; +use super::wire::{AsBytes, ParseBytesByRef, U16}; //----------- Message -------------------------------------------------------- @@ -19,6 +19,19 @@ pub struct Message { pub contents: [u8], } +//--- Interaction + +impl Message { + /// Truncate the contents of this message to the given size. + /// + /// The returned value will have a `contents` field of the given size. + pub fn slice_to(&self, size: usize) -> &Self { + let bytes = &self.as_bytes()[..12 + size]; + Self::parse_bytes_by_ref(bytes) + .expect("A 12-or-more byte string is a valid 'Message'") + } +} + //----------- Header --------------------------------------------------------- /// A DNS message header. diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index ba7cdb8c6..6432219a3 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -5,16 +5,13 @@ use core::{ cmp::Ordering, fmt, hash::{Hash, Hasher}, - ops::{Deref, Range}, + ops::Deref, }; use crate::new_base::{ build::{self, BuildIntoMessage}, parse::{ParseFromMessage, SplitFromMessage}, - wire::{ - AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes, - TruncationError, - }, + wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, }; @@ -255,13 +252,13 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { // disallow a name to point to data _after_ it. Standard name // compressors will never generate such pointers. - let message = message.as_bytes(); + let contents = &message.contents; let mut buffer = Self::empty(); // Perform the first iteration early, to catch the end of the name. - let bytes = message.get(start..).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; - let orig_end = message.len() - rest.len(); + let orig_end = contents.len() - rest.len(); // Traverse compression pointers. let mut old_start = start; @@ -272,7 +269,7 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { } // Keep going, from the referenced position. - let bytes = message.get(start..).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; old_start = start; continue; @@ -288,17 +285,17 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { impl<'a> ParseFromMessage<'a> for RevNameBuf { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { // See 'split_from_message()' for details. The only differences are // in the range of the first iteration, and the check that the first // iteration exactly covers the input range. - let message = message.as_bytes(); + let contents = &message.contents; let mut buffer = Self::empty(); // Perform the first iteration early, to catch the end of the name. - let bytes = message.get(range.clone()).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; if !rest.is_empty() { @@ -307,7 +304,7 @@ impl<'a> ParseFromMessage<'a> for RevNameBuf { } // Traverse compression pointers. - let mut old_start = range.start; + let mut old_start = start; while let Some(start) = pointer.map(usize::from) { // Ensure the referenced position comes earlier. if start >= old_start { @@ -315,7 +312,7 @@ impl<'a> ParseFromMessage<'a> for RevNameBuf { } // Keep going, from the referenced position. - let bytes = message.get(start..).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; old_start = start; continue; diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index d36dd9543..7e5a08d7f 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,7 +1,5 @@ //! Parsing DNS messages from the wire format. -use core::ops::Range; - mod message; pub use message::{MessagePart, ParseMessage, VisitMessagePart}; @@ -14,7 +12,7 @@ pub use record::{ParseRecord, ParseRecords, VisitRecord}; pub use super::wire::ParseError; use super::{ - wire::{AsBytes, ParseBytesByRef, SplitBytesByRef}, + wire::{ParseBytesByRef, SplitBytesByRef}, Message, }; @@ -22,11 +20,14 @@ use super::{ /// A type that can be parsed from a DNS message. pub trait SplitFromMessage<'a>: Sized + ParseFromMessage<'a> { - /// Parse a value of [`Self`] from the start of a byte string within a - /// particular DNS message. + /// Parse a value from the start of a byte string within a DNS message. + /// + /// The byte string to parse is `message.contents[start..]`. The previous + /// data in the message can be used for resolving compressed names. /// - /// If parsing is successful, the parsed value and the rest of the string - /// are returned. Otherwise, a [`ParseError`] is returned. + /// If parsing is successful, the parsed value and the offset for the rest + /// of the input are returned. If `len` bytes were parsed to form `self`, + /// `start + len` should be the returned offset. fn split_from_message( message: &'a Message, start: usize, @@ -35,14 +36,15 @@ pub trait SplitFromMessage<'a>: Sized + ParseFromMessage<'a> { /// A type that can be parsed from a string in a DNS message. pub trait ParseFromMessage<'a>: Sized { - /// Parse a value of [`Self`] from a byte string within a particular DNS - /// message. + /// Parse a value from a byte string within a DNS message. + /// + /// The byte string to parse is `message.contents[start..]`. The previous + /// data in the message can be used for resolving compressed names. /// - /// If parsing is successful, the parsed value is returned. Otherwise, a - /// [`ParseError`] is returned. + /// If parsing is successful, the parsed value is returned. fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result; } @@ -51,20 +53,18 @@ impl<'a, T: ?Sized + SplitBytesByRef> SplitFromMessage<'a> for &'a T { message: &'a Message, start: usize, ) -> Result<(Self, usize), ParseError> { - let message = message.as_bytes(); - let bytes = message.get(start..).ok_or(ParseError)?; + let bytes = message.contents.get(start..).ok_or(ParseError)?; let (this, rest) = T::split_bytes_by_ref(bytes)?; - Ok((this, message.len() - rest.len())) + Ok((this, bytes.len() - rest.len())) } } impl<'a, T: ?Sized + ParseBytesByRef> ParseFromMessage<'a> for &'a T { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - let message = message.as_bytes(); - let bytes = message.get(range).ok_or(ParseError)?; + let bytes = message.contents.get(start..).ok_or(ParseError)?; T::parse_bytes_by_ref(bytes) } } diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 0dad0910a..b961f0af7 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -1,7 +1,5 @@ //! DNS questions. -use core::ops::Range; - use domain_macros::*; use super::{ @@ -66,12 +64,11 @@ where { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - let (qname, rest) = N::split_from_message(message, range.start)?; + let (qname, rest) = N::split_from_message(message, start)?; let (&qtype, rest) = <&QType>::split_from_message(message, rest)?; - let &qclass = - <&QClass>::parse_from_message(message, rest..range.end)?; + let &qclass = <&QClass>::parse_from_message(message, rest)?; Ok(Self::new(qname, qtype, qclass)) } } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 2d84e0934..5686f09f3 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -1,9 +1,6 @@ //! DNS records. -use core::{ - borrow::Borrow, - ops::{Deref, Range}, -}; +use core::{borrow::Borrow, ops::Deref}; use super::{ build::{self, BuildIntoMessage}, @@ -78,8 +75,9 @@ where let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; let (&size, rest) = <&U16>::split_from_message(message, rest)?; let size: usize = size.get().into(); - let rdata = if message.as_bytes().len() - rest >= size { - D::parse_record_data(message, rest..rest + size, rtype)? + let rdata = if message.contents.len() - rest >= size { + let message = message.slice_to(rest + size); + D::parse_record_data(message, rest, rtype)? } else { return Err(ParseError); }; @@ -95,15 +93,11 @@ where { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - let message = &message.as_bytes()[..range.end]; - let message = Message::parse_bytes_by_ref(message) - .expect("The input range ends past the message header"); + let (this, rest) = Self::split_from_message(message, start)?; - let (this, rest) = Self::split_from_message(message, range.start)?; - - if rest == range.end { + if rest == message.contents.len() { Ok(this) } else { Err(ParseError) @@ -334,10 +328,10 @@ pub trait ParseRecordData<'a>: Sized { /// Parse DNS record data of the given type from a DNS message. fn parse_record_data( message: &'a Message, - range: Range, + start: usize, rtype: RType, ) -> Result { - let bytes = message.as_bytes().get(range).ok_or(ParseError)?; + let bytes = message.contents.get(start..).ok_or(ParseError)?; Self::parse_record_data_bytes(bytes, rtype) } diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 781224360..9afd69167 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -2,7 +2,7 @@ //! //! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). -use core::{fmt, ops::Range}; +use core::fmt; use domain_macros::*; @@ -54,20 +54,20 @@ impl<'a> SplitFromMessage<'a> for EdnsRecord<'a> { message: &'a Message, start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = message.as_bytes().get(start..).ok_or(ParseError)?; + let bytes = message.contents.get(start..).ok_or(ParseError)?; let (this, rest) = Self::split_bytes(bytes)?; - Ok((this, message.as_bytes().len() - rest.len())) + Ok((this, message.contents.len() - rest.len())) } } impl<'a> ParseFromMessage<'a> for EdnsRecord<'a> { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { message - .as_bytes() - .get(range) + .contents + .get(start..) .ok_or(ParseError) .and_then(Self::parse_bytes) } diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 14c7cdc9f..0f295ec5d 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -2,7 +2,7 @@ //! //! See [RFC 1035](https://datatracker.ietf.org/doc/html/rfc1035). -use core::{fmt, ops::Range}; +use core::fmt; #[cfg(feature = "std")] use core::str::FromStr; @@ -123,9 +123,9 @@ pub struct Ns { impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - N::parse_from_message(message, range).map(|name| Self { name }) + N::parse_from_message(message, start).map(|name| Self { name }) } } @@ -167,9 +167,9 @@ pub struct CName { impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for CName { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - N::parse_from_message(message, range).map(|name| Self { name }) + N::parse_from_message(message, start).map(|name| Self { name }) } } @@ -226,15 +226,15 @@ pub struct Soa { impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - let (mname, rest) = N::split_from_message(message, range.start)?; + let (mname, rest) = N::split_from_message(message, start)?; let (rname, rest) = N::split_from_message(message, rest)?; let (&serial, rest) = <&Serial>::split_from_message(message, rest)?; let (&refresh, rest) = <&U32>::split_from_message(message, rest)?; let (&retry, rest) = <&U32>::split_from_message(message, rest)?; let (&expire, rest) = <&U32>::split_from_message(message, rest)?; - let &minimum = <&U32>::parse_from_message(message, rest..range.end)?; + let &minimum = <&U32>::parse_from_message(message, rest)?; Ok(Self { mname, @@ -349,9 +349,9 @@ pub struct Ptr { impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - N::parse_from_message(message, range).map(|name| Self { name }) + N::parse_from_message(message, start).map(|name| Self { name }) } } @@ -383,11 +383,11 @@ pub struct HInfo<'a> { impl<'a> ParseFromMessage<'a> for HInfo<'a> { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { message - .as_bytes() - .get(range) + .contents + .get(start..) .ok_or(ParseError) .and_then(Self::parse_bytes) } @@ -437,11 +437,10 @@ pub struct Mx { impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { - let (&preference, rest) = - <&U16>::split_from_message(message, range.start)?; - let exchange = N::parse_from_message(message, rest..range.end)?; + let (&preference, rest) = <&U16>::split_from_message(message, start)?; + let exchange = N::parse_from_message(message, rest)?; Ok(Self { preference, exchange, @@ -496,11 +495,11 @@ impl Txt { impl<'a> ParseFromMessage<'a> for &'a Txt { fn parse_from_message( message: &'a Message, - range: Range, + start: usize, ) -> Result { message - .as_bytes() - .get(range) + .contents + .get(start..) .ok_or(ParseError) .and_then(Self::parse_bytes) } diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 23be1e18f..7617e4ad5 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -1,7 +1,5 @@ //! Record data types. -use core::ops::Range; - use domain_macros::*; use crate::new_base::{ @@ -70,35 +68,35 @@ where { fn parse_record_data( message: &'a Message, - range: Range, + start: usize, rtype: RType, ) -> Result { match rtype { - RType::A => <&A>::parse_from_message(message, range).map(Self::A), - RType::NS => Ns::parse_from_message(message, range).map(Self::Ns), + RType::A => <&A>::parse_from_message(message, start).map(Self::A), + RType::NS => Ns::parse_from_message(message, start).map(Self::Ns), RType::CNAME => { - CName::parse_from_message(message, range).map(Self::CName) + CName::parse_from_message(message, start).map(Self::CName) } RType::SOA => { - Soa::parse_from_message(message, range).map(Self::Soa) + Soa::parse_from_message(message, start).map(Self::Soa) } RType::WKS => { - <&Wks>::parse_from_message(message, range).map(Self::Wks) + <&Wks>::parse_from_message(message, start).map(Self::Wks) } RType::PTR => { - Ptr::parse_from_message(message, range).map(Self::Ptr) + Ptr::parse_from_message(message, start).map(Self::Ptr) } RType::HINFO => { - HInfo::parse_from_message(message, range).map(Self::HInfo) + HInfo::parse_from_message(message, start).map(Self::HInfo) } - RType::MX => Mx::parse_from_message(message, range).map(Self::Mx), + RType::MX => Mx::parse_from_message(message, start).map(Self::Mx), RType::TXT => { - <&Txt>::parse_from_message(message, range).map(Self::Txt) + <&Txt>::parse_from_message(message, start).map(Self::Txt) } RType::AAAA => { - <&Aaaa>::parse_from_message(message, range).map(Self::Aaaa) + <&Aaaa>::parse_from_message(message, start).map(Self::Aaaa) } - _ => <&UnknownRecordData>::parse_from_message(message, range) + _ => <&UnknownRecordData>::parse_from_message(message, start) .map(|data| Self::Unknown(rtype, data)), } } From 1037ba05dd889e7afb3390ef8b0e1f76dacccf42 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 7 Jan 2025 12:12:45 +0100 Subject: [PATCH 076/167] [new_base/name] Delete unused 'parsed.rs' --- src/new_base/name/parsed.rs | 131 ------------------------------------ 1 file changed, 131 deletions(-) delete mode 100644 src/new_base/name/parsed.rs diff --git a/src/new_base/name/parsed.rs b/src/new_base/name/parsed.rs deleted file mode 100644 index abf592e5d..000000000 --- a/src/new_base/name/parsed.rs +++ /dev/null @@ -1,131 +0,0 @@ -//! Domain names encoded in DNS messages. - -use zerocopy_derive::*; - -use crate::new_base::parse::{ParseError, ParseFrom, SplitFrom}; - -//----------- ParsedName ----------------------------------------------------- - -/// A domain name in a DNS message. -#[derive(Debug, IntoBytes, Immutable, Unaligned)] -#[repr(transparent)] -pub struct ParsedName([u8]); - -//--- Constants - -impl ParsedName { - /// The maximum size of a parsed domain name in the wire format. - /// - /// This can occur if a compression pointer is used to point to a root - /// name, even though such a representation is longer than copying the - /// root label into the name. - pub const MAX_SIZE: usize = 256; - - /// The root name. - pub const ROOT: &'static Self = { - // SAFETY: A root label is the shortest valid name. - unsafe { Self::from_bytes_unchecked(&[0u8]) } - }; -} - -//--- Construction - -impl ParsedName { - /// Assume a byte string is a valid [`ParsedName`]. - /// - /// # Safety - /// - /// The byte string must be correctly encoded in the wire format, and - /// within the size restriction (256 bytes or fewer). It must end with a - /// root label or a compression pointer. - pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { - // SAFETY: 'ParsedName' is 'repr(transparent)' to '[u8]', so casting a - // '[u8]' into a 'ParsedName' is sound. - core::mem::transmute(bytes) - } -} - -//--- Inspection - -impl ParsedName { - /// The size of this name in the wire format. - #[allow(clippy::len_without_is_empty)] - pub const fn len(&self) -> usize { - self.0.len() - } - - /// Whether this is the root label. - pub const fn is_root(&self) -> bool { - self.0.len() == 1 - } - - /// Whether this is a compression pointer. - pub const fn is_pointer(&self) -> bool { - self.0.len() == 2 - } - - /// The wire format representation of the name. - pub const fn as_bytes(&self) -> &[u8] { - &self.0 - } -} - -//--- Parsing - -impl<'a> SplitFrom<'a> for &'a ParsedName { - fn split_from(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { - // Iterate through the labels in the name. - let mut index = 0usize; - loop { - if index >= ParsedName::MAX_SIZE || index >= bytes.len() { - return Err(ParseError); - } - let length = bytes[index]; - if length == 0 { - // This was the root label. - index += 1; - break; - } else if length < 0x40 { - // This was the length of the label. - index += 1 + length as usize; - } else if length >= 0xC0 { - // This was a compression pointer. - if index + 1 >= bytes.len() { - return Err(ParseError); - } - index += 2; - break; - } else { - // This was a reserved or deprecated label type. - return Err(ParseError); - } - } - - let (name, bytes) = bytes.split_at(index); - // SAFETY: 'bytes' has been confirmed to be correctly encoded. - Ok((unsafe { ParsedName::from_bytes_unchecked(name) }, bytes)) - } -} - -impl<'a> ParseFrom<'a> for &'a ParsedName { - fn parse_from(bytes: &'a [u8]) -> Result { - Self::split_from(bytes).and_then(|(name, rest)| { - rest.is_empty().then_some(name).ok_or(ParseError) - }) - } -} - -//--- Conversion to and from bytes - -impl AsRef<[u8]> for ParsedName { - /// The bytes in the name in the wire format. - fn as_ref(&self) -> &[u8] { - &self.0 - } -} - -impl<'a> From<&'a ParsedName> for &'a [u8] { - fn from(name: &'a ParsedName) -> Self { - name.as_bytes() - } -} From 5faf44d2739dae80f9b2bacc3536348ddb6757d9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 7 Jan 2025 14:53:11 +0100 Subject: [PATCH 077/167] [new_base/wire] Define 'SizePrefixed' --- src/new_base/name/reversed.rs | 1 + src/new_base/record.rs | 59 +++---- src/new_base/wire/mod.rs | 5 +- src/new_base/wire/size_prefixed.rs | 248 +++++++++++++++++++++++++++++ src/new_edns/mod.rs | 96 ++++++----- 5 files changed, 317 insertions(+), 92 deletions(-) create mode 100644 src/new_base/wire/size_prefixed.rs diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 6432219a3..e851911f3 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -269,6 +269,7 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { } // Keep going, from the referenced position. + let start = start.checked_sub(12).ok_or(ParseError)?; let bytes = contents.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; old_start = start; diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 5686f09f3..f65d1c5d5 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -8,7 +8,7 @@ use super::{ parse::{ParseFromMessage, SplitFromMessage}, wire::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, - SplitBytes, SplitBytesByRef, TruncationError, U16, U32, + SizePrefixed, SplitBytes, SplitBytesByRef, TruncationError, U16, U32, }, Message, }; @@ -73,16 +73,13 @@ where let (&rtype, rest) = <&RType>::split_from_message(message, rest)?; let (&rclass, rest) = <&RClass>::split_from_message(message, rest)?; let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; - let (&size, rest) = <&U16>::split_from_message(message, rest)?; - let size: usize = size.get().into(); - let rdata = if message.contents.len() - rest >= size { - let message = message.slice_to(rest + size); - D::parse_record_data(message, rest, rtype)? - } else { - return Err(ParseError); - }; + let rdata_start = rest; + let (_, rest) = + <&SizePrefixed<[u8]>>::split_from_message(message, rest)?; + let message = message.slice_to(rest); + let rdata = D::parse_record_data(message, rdata_start, rtype)?; - Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest + size)) + Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } } @@ -95,13 +92,14 @@ where message: &'a Message, start: usize, ) -> Result { - let (this, rest) = Self::split_from_message(message, start)?; + let (rname, rest) = N::split_from_message(message, start)?; + let (&rtype, rest) = <&RType>::split_from_message(message, rest)?; + let (&rclass, rest) = <&RClass>::split_from_message(message, rest)?; + let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; + let _ = <&SizePrefixed<[u8]>>::parse_from_message(message, rest)?; + let rdata = D::parse_record_data(message, rest, rtype)?; - if rest == message.contents.len() { - Ok(this) - } else { - Err(ParseError) - } + Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } } @@ -148,13 +146,7 @@ where let (rtype, rest) = RType::split_bytes(rest)?; let (rclass, rest) = RClass::split_bytes(rest)?; let (ttl, rest) = TTL::split_bytes(rest)?; - let (size, rest) = U16::split_bytes(rest)?; - let size: usize = size.get().into(); - if rest.len() < size { - return Err(ParseError); - } - - let (rdata, rest) = rest.split_at(size); + let (rdata, rest) = <&SizePrefixed<[u8]>>::split_bytes(rest)?; let rdata = D::parse_record_data_bytes(rdata, rtype)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) @@ -171,13 +163,8 @@ where let (rtype, rest) = RType::split_bytes(rest)?; let (rclass, rest) = RClass::split_bytes(rest)?; let (ttl, rest) = TTL::split_bytes(rest)?; - let (size, rest) = U16::split_bytes(rest)?; - let size: usize = size.get().into(); - if rest.len() != size { - return Err(ParseError); - } - - let rdata = D::parse_record_data_bytes(rest, rtype)?; + let rdata = <&SizePrefixed<[u8]>>::parse_bytes(rest)?; + let rdata = D::parse_record_data_bytes(rdata, rtype)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } @@ -198,17 +185,9 @@ where bytes = self.rtype.as_bytes().build_bytes(bytes)?; bytes = self.rclass.as_bytes().build_bytes(bytes)?; bytes = self.ttl.as_bytes().build_bytes(bytes)?; + bytes = SizePrefixed::new(&self.rdata).build_bytes(bytes)?; - let (size, bytes) = - U16::split_bytes_by_mut(bytes).map_err(|_| TruncationError)?; - let bytes_len = bytes.len(); - - let rest = self.rdata.build_bytes(bytes)?; - *size = u16::try_from(bytes_len - rest.len()) - .expect("the record data never exceeds 64KiB") - .into(); - - Ok(rest) + Ok(bytes) } } diff --git a/src/new_base/wire/mod.rs b/src/new_base/wire/mod.rs index 4d5be5c25..41f131af7 100644 --- a/src/new_base/wire/mod.rs +++ b/src/new_base/wire/mod.rs @@ -1,4 +1,4 @@ -//! The basic wire format of network protocols. +//! Low-level byte serialization. //! //! This is a low-level module providing simple and efficient mechanisms to //! parse data from and build data into byte sequences. It takes inspiration @@ -79,3 +79,6 @@ pub use parse::{ mod ints; pub use ints::{U16, U32, U64}; + +mod size_prefixed; +pub use size_prefixed::SizePrefixed; diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs new file mode 100644 index 000000000..5e4fc217e --- /dev/null +++ b/src/new_base/wire/size_prefixed.rs @@ -0,0 +1,248 @@ +//! Working with (U16-)size-prefixed data. + +use core::{ + borrow::{Borrow, BorrowMut}, + ops::{Deref, DerefMut}, +}; + +use super::{ + AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, SplitBytes, + SplitBytesByRef, TruncationError, U16, +}; + +//----------- SizePrefixed --------------------------------------------------- + +/// A wrapper adding a 16-bit size prefix to a message. +/// +/// This is a common element in DNS messages (e.g. for record data and EDNS +/// options). When serialized as bytes, the inner value is prefixed with a +/// 16-bit network-endian integer indicating the length of the inner value in +/// bytes. +#[derive(Copy, Clone)] +#[repr(C)] +pub struct SizePrefixed { + /// The size prefix (needed for 'ParseBytesByRef' / 'AsBytes'). + /// + /// This value is always consistent with the size of 'data' if it is + /// (de)serialized in-place. By the bounds on 'ParseBytesByRef' and + /// 'AsBytes', the serialized size is the same as 'size_of_val(&data)'. + size: U16, + + /// The inner data. + data: T, +} + +//--- Construction + +impl SizePrefixed { + const VALID_SIZE: () = assert!(core::mem::size_of::() < 65536); + + /// Construct a [`SizePrefixed`]. + /// + /// # Panics + /// + /// Panics if the data is 64KiB or more in size. + pub const fn new(data: T) -> Self { + // Force the 'VALID_SIZE' assertion to be evaluated. + #[allow(clippy::let_unit_value)] + let _ = Self::VALID_SIZE; + + Self { + size: U16::new(core::mem::size_of::() as u16), + data, + } + } +} + +//--- Conversion from the inner data + +impl From for SizePrefixed { + fn from(value: T) -> Self { + Self::new(value) + } +} + +//--- Access to the inner data + +impl Deref for SizePrefixed { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.data + } +} + +impl DerefMut for SizePrefixed { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.data + } +} + +impl Borrow for SizePrefixed { + fn borrow(&self) -> &T { + &self.data + } +} + +impl BorrowMut for SizePrefixed { + fn borrow_mut(&mut self) -> &mut T { + &mut self.data + } +} + +impl AsRef for SizePrefixed { + fn as_ref(&self) -> &T { + &self.data + } +} + +impl AsMut for SizePrefixed { + fn as_mut(&mut self) -> &mut T { + &mut self.data + } +} + +//--- Parsing from bytes + +impl<'b, T: ParseBytes<'b>> ParseBytes<'b> for SizePrefixed { + fn parse_bytes(bytes: &'b [u8]) -> Result { + let (size, rest) = U16::split_bytes(bytes)?; + if rest.len() != size.get() as usize { + return Err(ParseError); + } + let data = T::parse_bytes(bytes)?; + Ok(Self { size, data }) + } +} + +impl<'b, T: ParseBytes<'b>> SplitBytes<'b> for SizePrefixed { + fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { + let (size, rest) = U16::split_bytes(bytes)?; + if rest.len() < size.get() as usize { + return Err(ParseError); + } + let (data, rest) = rest.split_at(size.get() as usize); + let data = T::parse_bytes(data)?; + Ok((Self { size, data }, rest)) + } +} + +unsafe impl ParseBytesByRef for SizePrefixed { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + let addr = bytes.as_ptr(); + let (size, rest) = U16::split_bytes_by_ref(bytes)?; + if rest.len() != size.get() as usize { + return Err(ParseError); + } + let last = T::parse_bytes_by_ref(rest)?; + let ptr = last.ptr_with_address(addr as *const ()); + + // SAFETY: + // - 'bytes' is a 'U16' followed by a 'T'. + // - 'T' is 'ParseBytesByRef' and so is unaligned. + // - 'Self' is 'repr(C)' and so has no alignment or padding. + // - The layout of 'Self' is identical to '(U16, T)'. + Ok(unsafe { &*(ptr as *const Self) }) + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + let addr = bytes.as_ptr(); + let (size, rest) = U16::split_bytes_by_mut(bytes)?; + if rest.len() != size.get() as usize { + return Err(ParseError); + } + let last = T::parse_bytes_by_mut(rest)?; + let ptr = last.ptr_with_address(addr as *const ()); + + // SAFETY: + // - 'bytes' is a 'U16' followed by a 'T'. + // - 'T' is 'ParseBytesByRef' and so is unaligned. + // - 'Self' is 'repr(C)' and so has no alignment or padding. + // - The layout of 'Self' is identical to '(U16, T)'. + Ok(unsafe { &mut *(ptr as *const Self as *mut Self) }) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + self.data.ptr_with_address(addr) as *const Self + } +} + +unsafe impl SplitBytesByRef for SizePrefixed { + fn split_bytes_by_ref( + bytes: &[u8], + ) -> Result<(&Self, &[u8]), ParseError> { + let addr = bytes.as_ptr(); + let (size, rest) = U16::split_bytes_by_ref(bytes)?; + if rest.len() < size.get() as usize { + return Err(ParseError); + } + let (data, rest) = rest.split_at(size.get() as usize); + let last = T::parse_bytes_by_ref(data)?; + let ptr = last.ptr_with_address(addr as *const ()); + + // SAFETY: + // - 'bytes' is a 'U16' followed by a 'T'. + // - 'T' is 'ParseBytesByRef' and so is unaligned. + // - 'Self' is 'repr(C)' and so has no alignment or padding. + // - The layout of 'Self' is identical to '(U16, T)'. + Ok((unsafe { &*(ptr as *const Self) }, rest)) + } + + fn split_bytes_by_mut( + bytes: &mut [u8], + ) -> Result<(&mut Self, &mut [u8]), ParseError> { + let addr = bytes.as_ptr(); + let (size, rest) = U16::split_bytes_by_mut(bytes)?; + if rest.len() < size.get() as usize { + return Err(ParseError); + } + let (data, rest) = rest.split_at_mut(size.get() as usize); + let last = T::parse_bytes_by_mut(data)?; + let ptr = last.ptr_with_address(addr as *const ()); + + // SAFETY: + // - 'bytes' is a 'U16' followed by a 'T'. + // - 'T' is 'ParseBytesByRef' and so is unaligned. + // - 'Self' is 'repr(C)' and so has no alignment or padding. + // - The layout of 'Self' is identical to '(U16, T)'. + Ok((unsafe { &mut *(ptr as *const Self as *mut Self) }, rest)) + } +} + +//--- Building into byte strings + +impl BuildBytes for SizePrefixed { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + // Get the size area to fill in afterwards. + let (size_buf, data_buf) = + U16::split_bytes_by_mut(bytes).map_err(|_| TruncationError)?; + let data_buf_len = data_buf.len(); + let rest = self.data.build_bytes(data_buf)?; + let size = data_buf_len - rest.len(); + assert!(size < 65536, "Cannot serialize >=64KiB into 16-bit integer"); + *size_buf = U16::new(size as u16); + Ok(rest) + } +} + +unsafe impl AsBytes for SizePrefixed { + // For debugging, we check that the serialized size is correct. + #[cfg(debug_assertions)] + fn as_bytes(&self) -> &[u8] { + let size: usize = self.size.get().into(); + assert_eq!(size, core::mem::size_of_val(&self.data)); + + // SAFETY: + // - 'Self' has no padding bytes and no interior mutability. + // - Its size in memory is exactly 'size_of_val(self)'. + unsafe { + core::slice::from_raw_parts( + self as *const Self as *const u8, + core::mem::size_of_val(self), + ) + } + } +} diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 9afd69167..152cd5dae 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -11,7 +11,7 @@ use crate::{ parse::{ParseFromMessage, SplitFromMessage}, wire::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, - SplitBytes, TruncationError, U16, + SizePrefixed, SplitBytes, TruncationError, U16, }, Message, }, @@ -44,7 +44,7 @@ pub struct EdnsRecord<'a> { pub flags: EdnsFlags, /// Extended DNS options. - pub options: &'a Opt, + pub options: SizePrefixed<&'a Opt>, } //--- Parsing from DNS messages @@ -84,15 +84,7 @@ impl<'a> SplitBytes<'a> for EdnsRecord<'a> { let (&ext_rcode, rest) = <&u8>::split_bytes(rest)?; let (&version, rest) = <&u8>::split_bytes(rest)?; let (&flags, rest) = <&EdnsFlags>::split_bytes(rest)?; - - // Split the record size and data. - let (&size, rest) = <&U16>::split_bytes(rest)?; - let size: usize = size.get().into(); - if rest.len() < size { - return Err(ParseError); - } - let (options, rest) = rest.split_at(size); - let options = Opt::parse_bytes_by_ref(options)?; + let (options, rest) = >::split_bytes(rest)?; Ok(( Self { @@ -116,14 +108,7 @@ impl<'a> ParseBytes<'a> for EdnsRecord<'a> { let (&ext_rcode, rest) = <&u8>::split_bytes(rest)?; let (&version, rest) = <&u8>::split_bytes(rest)?; let (&flags, rest) = <&EdnsFlags>::split_bytes(rest)?; - - // Split the record size and data. - let (&size, rest) = <&U16>::split_bytes(rest)?; - let size: usize = size.get().into(); - if rest.len() != size { - return Err(ParseError); - } - let options = Opt::parse_bytes_by_ref(rest)?; + let options = >::parse_bytes(rest)?; Ok(Self { max_udp_payload, @@ -135,6 +120,26 @@ impl<'a> ParseBytes<'a> for EdnsRecord<'a> { } } +//--- Building into bytes + +impl BuildBytes for EdnsRecord<'_> { + fn build_bytes<'b>( + &self, + mut bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + // Add the record name (root) and the record type. + bytes = [0, 0, 41].as_slice().build_bytes(bytes)?; + + bytes = self.max_udp_payload.build_bytes(bytes)?; + bytes = self.ext_rcode.build_bytes(bytes)?; + bytes = self.version.build_bytes(bytes)?; + bytes = self.flags.build_bytes(bytes)?; + bytes = self.options.build_bytes(bytes)?; + + Ok(bytes) + } +} + //----------- EdnsFlags ------------------------------------------------------ /// Extended DNS flags describing a message. @@ -239,25 +244,22 @@ impl EdnsOption<'_> { impl<'b> ParseBytes<'b> for EdnsOption<'b> { fn parse_bytes(bytes: &'b [u8]) -> Result { let (code, rest) = OptionCode::split_bytes(bytes)?; - let (size, rest) = U16::split_bytes(rest)?; - if rest.len() != size.get() as usize { - return Err(ParseError); - } + let data = <&SizePrefixed<[u8]>>::parse_bytes(rest)?; match code { - OptionCode::COOKIE => match size.get() { - 8 => CookieRequest::parse_bytes_by_ref(rest) + OptionCode::COOKIE => match data.len() { + 8 => CookieRequest::parse_bytes_by_ref(data) .map(Self::CookieRequest), - 16..=40 => Cookie::parse_bytes_by_ref(rest).map(Self::Cookie), + 16..=40 => Cookie::parse_bytes_by_ref(data).map(Self::Cookie), _ => Err(ParseError), }, OptionCode::EXT_ERROR => { - ExtError::parse_bytes_by_ref(rest).map(Self::ExtError) + ExtError::parse_bytes_by_ref(data).map(Self::ExtError) } _ => { - let data = UnknownOption::parse_bytes_by_ref(rest)?; + let data = UnknownOption::parse_bytes_by_ref(data)?; Ok(Self::Unknown(code, data)) } } @@ -267,32 +269,25 @@ impl<'b> ParseBytes<'b> for EdnsOption<'b> { impl<'b> SplitBytes<'b> for EdnsOption<'b> { fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { let (code, rest) = OptionCode::split_bytes(bytes)?; - let (size, rest) = U16::split_bytes(rest)?; - if rest.len() < size.get() as usize { - return Err(ParseError); - } - let (bytes, rest) = rest.split_at(size.get() as usize); - - match code { - OptionCode::COOKIE => match size.get() { - 8 => CookieRequest::parse_bytes_by_ref(bytes) - .map(Self::CookieRequest), - 16..=40 => { - Cookie::parse_bytes_by_ref(bytes).map(Self::Cookie) - } - _ => Err(ParseError), + let (data, rest) = <&SizePrefixed<[u8]>>::split_bytes(rest)?; + + let this = match code { + OptionCode::COOKIE => match data.len() { + 8 => <&CookieRequest>::parse_bytes(data) + .map(Self::CookieRequest)?, + 16..=40 => <&Cookie>::parse_bytes(data).map(Self::Cookie)?, + _ => return Err(ParseError), }, OptionCode::EXT_ERROR => { - ExtError::parse_bytes_by_ref(bytes).map(Self::ExtError) + <&ExtError>::parse_bytes(data).map(Self::ExtError)? } - _ => { - let data = UnknownOption::parse_bytes_by_ref(bytes)?; - Ok(Self::Unknown(code, data)) - } - } - .map(|this| (this, rest)) + _ => <&UnknownOption>::parse_bytes(data) + .map(|data| Self::Unknown(code, data))?, + }; + + Ok((this, rest)) } } @@ -311,9 +306,8 @@ impl BuildBytes for EdnsOption<'_> { Self::ExtError(this) => this.as_bytes(), Self::Unknown(_, this) => this.as_bytes(), }; + bytes = SizePrefixed::new(data).build_bytes(bytes)?; - bytes = U16::new(data.len() as u16).build_bytes(bytes)?; - bytes = data.build_bytes(bytes)?; Ok(bytes) } } From a98d246e5bab2fd192565ed09814491ff2a69338 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 7 Jan 2025 15:33:18 +0100 Subject: [PATCH 078/167] [new_rdata/edns] Implement iteration and formatting --- src/new_base/record.rs | 29 ++++++++----- src/new_rdata/edns.rs | 93 +++++++++++++++++++++++++++++++++++++++++- src/new_rdata/mod.rs | 15 +++++-- 3 files changed, 122 insertions(+), 15 deletions(-) diff --git a/src/new_base/record.rs b/src/new_base/record.rs index f65d1c5d5..8fa3f30f4 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -219,35 +219,44 @@ pub struct RType { //--- Associated Constants impl RType { + const fn new(value: u16) -> Self { + Self { + code: U16::new(value), + } + } + /// The type of an [`A`](crate::new_rdata::A) record. - pub const A: Self = Self { code: U16::new(1) }; + pub const A: Self = Self::new(1); /// The type of an [`Ns`](crate::new_rdata::Ns) record. - pub const NS: Self = Self { code: U16::new(2) }; + pub const NS: Self = Self::new(2); /// The type of a [`CName`](crate::new_rdata::CName) record. - pub const CNAME: Self = Self { code: U16::new(5) }; + pub const CNAME: Self = Self::new(5); /// The type of an [`Soa`](crate::new_rdata::Soa) record. - pub const SOA: Self = Self { code: U16::new(6) }; + pub const SOA: Self = Self::new(6); /// The type of a [`Wks`](crate::new_rdata::Wks) record. - pub const WKS: Self = Self { code: U16::new(11) }; + pub const WKS: Self = Self::new(11); /// The type of a [`Ptr`](crate::new_rdata::Ptr) record. - pub const PTR: Self = Self { code: U16::new(12) }; + pub const PTR: Self = Self::new(12); /// The type of a [`HInfo`](crate::new_rdata::HInfo) record. - pub const HINFO: Self = Self { code: U16::new(13) }; + pub const HINFO: Self = Self::new(13); /// The type of a [`Mx`](crate::new_rdata::Mx) record. - pub const MX: Self = Self { code: U16::new(15) }; + pub const MX: Self = Self::new(15); /// The type of a [`Txt`](crate::new_rdata::Txt) record. - pub const TXT: Self = Self { code: U16::new(16) }; + pub const TXT: Self = Self::new(16); /// The type of an [`Aaaa`](crate::new_rdata::Aaaa) record. - pub const AAAA: Self = Self { code: U16::new(28) }; + pub const AAAA: Self = Self::new(28); + + /// The type of an [`Opt`](crate::new_rdata::Opt) record. + pub const OPT: Self = Self::new(41); } //----------- RClass --------------------------------------------------------- diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index c53a715a7..5c2ccce29 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -2,9 +2,17 @@ //! //! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). +use core::{fmt, iter::FusedIterator}; + use domain_macros::*; -use crate::new_base::build::{self, BuildIntoMessage, TruncationError}; +use crate::{ + new_base::{ + build::{self, BuildIntoMessage, TruncationError}, + wire::{ParseError, SplitBytes}, + }, + new_edns::EdnsOption, +}; //----------- Opt ------------------------------------------------------------ @@ -18,7 +26,23 @@ pub struct Opt { contents: [u8], } -// TODO: Parsing the EDNS options. +//--- Inspection + +impl Opt { + /// Traverse the options in this record. + pub fn options(&self) -> EdnsOptionsIter<'_> { + EdnsOptionsIter::new(&self.contents) + } +} + +//--- Formatting + +impl fmt::Debug for Opt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Opt").field(&self.options()).finish() + } +} + // TODO: Formatting. //--- Building into DNS messages @@ -31,3 +55,68 @@ impl BuildIntoMessage for Opt { self.contents.build_into_message(builder) } } + +//----------- EdnsOptionsIter ------------------------------------------------ + +/// An iterator over EDNS options in an [`Opt`] record. +#[derive(Clone)] +pub struct EdnsOptionsIter<'a> { + /// The serialized options to parse from. + options: &'a [u8], +} + +//--- Construction + +impl<'a> EdnsOptionsIter<'a> { + /// Construct a new [`EdnsOptionsIter`]. + pub const fn new(options: &'a [u8]) -> Self { + Self { options } + } +} + +//--- Inspection + +impl<'a> EdnsOptionsIter<'a> { + /// The serialized options yet to be parsed. + pub const fn remaining(&self) -> &'a [u8] { + self.options + } +} + +//--- Formatting + +impl fmt::Debug for EdnsOptionsIter<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut entries = f.debug_set(); + for option in self.clone() { + match option { + Ok(option) => entries.entry(&option), + Err(_err) => entries.entry(&format_args!("")), + }; + } + entries.finish() + } +} + +//--- Iteration + +impl<'a> Iterator for EdnsOptionsIter<'a> { + type Item = Result, ParseError>; + + fn next(&mut self) -> Option { + if !self.options.is_empty() { + let options = core::mem::take(&mut self.options); + match EdnsOption::split_bytes(options) { + Ok((option, rest)) => { + self.options = rest; + Some(Ok(option)) + } + Err(err) => Some(Err(err)), + } + } else { + None + } + } +} + +impl FusedIterator for EdnsOptionsIter<'_> {} diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 7617e4ad5..ac1804ea8 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -18,7 +18,7 @@ mod ipv6; pub use ipv6::Aaaa; mod edns; -pub use edns::Opt; +pub use edns::{EdnsOptionsIter, Opt}; //----------- RecordData ----------------------------------------------------- @@ -56,6 +56,9 @@ pub enum RecordData<'a, N> { /// The IPv6 address of a host responsible for this domain. Aaaa(&'a Aaaa), + /// Extended DNS options. + Opt(&'a Opt), + /// Data for an unknown DNS record type. Unknown(RType, &'a UnknownRecordData), } @@ -96,6 +99,9 @@ where RType::AAAA => { <&Aaaa>::parse_from_message(message, start).map(Self::Aaaa) } + RType::OPT => { + <&Opt>::parse_from_message(message, start).map(Self::Opt) + } _ => <&UnknownRecordData>::parse_from_message(message, start) .map(|data| Self::Unknown(rtype, data)), } @@ -116,6 +122,7 @@ where RType::MX => Mx::parse_bytes(bytes).map(Self::Mx), RType::TXT => <&Txt>::parse_bytes(bytes).map(Self::Txt), RType::AAAA => <&Aaaa>::parse_bytes(bytes).map(Self::Aaaa), + RType::OPT => <&Opt>::parse_bytes(bytes).map(Self::Opt), _ => <&UnknownRecordData>::parse_bytes(bytes) .map(|data| Self::Unknown(rtype, data)), } @@ -137,9 +144,10 @@ impl BuildIntoMessage for RecordData<'_, N> { Self::Wks(r) => r.build_into_message(builder), Self::Ptr(r) => r.build_into_message(builder), Self::HInfo(r) => r.build_into_message(builder), + Self::Mx(r) => r.build_into_message(builder), Self::Txt(r) => r.build_into_message(builder), Self::Aaaa(r) => r.build_into_message(builder), - Self::Mx(r) => r.build_into_message(builder), + Self::Opt(r) => r.build_into_message(builder), Self::Unknown(_, r) => r.octets.build_into_message(builder), } } @@ -158,9 +166,10 @@ impl BuildBytes for RecordData<'_, N> { Self::Wks(r) => r.build_bytes(bytes), Self::Ptr(r) => r.build_bytes(bytes), Self::HInfo(r) => r.build_bytes(bytes), + Self::Mx(r) => r.build_bytes(bytes), Self::Txt(r) => r.build_bytes(bytes), Self::Aaaa(r) => r.build_bytes(bytes), - Self::Mx(r) => r.build_bytes(bytes), + Self::Opt(r) => r.build_bytes(bytes), Self::Unknown(_, r) => r.build_bytes(bytes), } } From 3ea33ac19fa5f097381dc16b42ddecb2c928f262 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 9 Jan 2025 08:12:37 +0100 Subject: [PATCH 079/167] [new_base/build] Use 'BuildResult' to ensure 'commit()' is called This has already caught a missing commit (for 'Question'). --- src/new_base/build/builder.rs | 9 +++++- src/new_base/build/mod.rs | 34 +++++++++++++---------- src/new_base/charstr.rs | 7 ++--- src/new_base/name/reversed.rs | 12 +++----- src/new_base/question.rs | 8 +++--- src/new_base/record.rs | 12 +++----- src/new_rdata/basic.rs | 52 ++++++++++------------------------- src/new_rdata/edns.rs | 7 ++--- src/new_rdata/ipv6.rs | 9 ++---- src/new_rdata/mod.rs | 7 ++--- 10 files changed, 63 insertions(+), 94 deletions(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index e02da91db..8a0bbba33 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -12,6 +12,8 @@ use crate::new_base::{ Header, Message, }; +use super::BuildCommitted; + //----------- Builder -------------------------------------------------------- /// A DNS message builder. @@ -214,8 +216,13 @@ impl Builder<'_> { } /// Commit all appended content. - pub fn commit(&mut self) { + /// + /// For convenience, a unit type [`BuildCommitted`] is returned; it is + /// used as the return type of build functions to remind users to call + /// this method on success paths. + pub fn commit(&mut self) -> BuildCommitted { self.commit = self.context.size; + BuildCommitted } /// Mark bytes in the buffer as initialized. diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 35752d2e2..86e21ddfc 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -3,7 +3,7 @@ mod builder; pub use builder::{Builder, BuilderContext}; -pub use super::wire::TruncationError; +use super::wire::TruncationError; //----------- Message-aware building traits ---------------------------------- @@ -13,28 +13,32 @@ pub trait BuildIntoMessage { /// /// If the builder has enough capacity to fit the message, it is appended /// and committed. Otherwise, a [`TruncationError`] is returned. - fn build_into_message( - &self, - builder: Builder<'_>, - ) -> Result<(), TruncationError>; + fn build_into_message(&self, builder: Builder<'_>) -> BuildResult; } impl BuildIntoMessage for &T { - fn build_into_message( - &self, - builder: Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: Builder<'_>) -> BuildResult { (**self).build_into_message(builder) } } impl BuildIntoMessage for [u8] { - fn build_into_message( - &self, - mut builder: Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, mut builder: Builder<'_>) -> BuildResult { builder.append_bytes(self)?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } + +//----------- BuildResult ---------------------------------------------------- + +/// The result of building into a DNS message. +pub type BuildResult = Result; + +//----------- BuildCommitted ------------------------------------------------- + +/// The output of [`Builder::commit()`]. +/// +/// This is a stub type to remind users to call [`Builder::commit()`] in all +/// success paths of building functions. +#[derive(Debug)] +pub struct BuildCommitted; diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index ce7e1fba7..8df3c3d7c 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -3,7 +3,7 @@ use core::fmt; use super::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, parse::{ParseFromMessage, SplitFromMessage}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, @@ -50,11 +50,10 @@ impl BuildIntoMessage for CharStr { fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { builder.append_bytes(&[self.octets.len() as u8])?; builder.append_bytes(&self.octets)?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index e851911f3..aa58b24cf 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -9,7 +9,7 @@ use core::{ }; use crate::new_base::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, parse::{ParseFromMessage, SplitFromMessage}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, @@ -99,10 +99,9 @@ impl BuildIntoMessage for RevName { fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { builder.append_name(self)?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } @@ -371,10 +370,7 @@ fn parse_segment<'a>( //--- Building into DNS messages impl BuildIntoMessage for RevNameBuf { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { (**self).build_into_message(builder) } } diff --git a/src/new_base/question.rs b/src/new_base/question.rs index b961f0af7..720d46e14 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -3,10 +3,10 @@ use domain_macros::*; use super::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, name::RevNameBuf, parse::{ParseFromMessage, SplitFromMessage}, - wire::{AsBytes, ParseError, TruncationError, U16}, + wire::{AsBytes, ParseError, U16}, Message, }; @@ -82,11 +82,11 @@ where fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { self.qname.build_into_message(builder.delegate())?; builder.append_bytes(self.qtype.as_bytes())?; builder.append_bytes(self.qclass.as_bytes())?; - Ok(()) + Ok(builder.commit()) } } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 8fa3f30f4..5e380ae0f 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -3,7 +3,7 @@ use core::{borrow::Borrow, ops::Deref}; use super::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, name::RevNameBuf, parse::{ParseFromMessage, SplitFromMessage}, wire::{ @@ -113,7 +113,7 @@ where fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { self.rname.build_into_message(builder.delegate())?; builder.append_bytes(self.rtype.as_bytes())?; builder.append_bytes(self.rclass.as_bytes())?; @@ -129,8 +129,7 @@ where builder.appended_mut()[offset..offset + 2] .copy_from_slice(&size.to_be_bytes()); - builder.commit(); - Ok(()) + Ok(builder.commit()) } } @@ -372,10 +371,7 @@ impl<'a> ParseRecordData<'a> for &'a UnparsedRecordData { //--- Building into DNS messages impl BuildIntoMessage for UnparsedRecordData { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.0.build_into_message(builder) } } diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 0f295ec5d..456da881c 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -13,12 +13,9 @@ use std::net::Ipv4Addr; use domain_macros::*; use crate::new_base::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, parse::{ParseFromMessage, SplitFromMessage}, - wire::{ - AsBytes, ParseBytes, ParseError, SplitBytes, TruncationError, U16, - U32, - }, + wire::{AsBytes, ParseBytes, ParseError, SplitBytes, U16, U32}, CharStr, Message, Serial, }; @@ -88,10 +85,7 @@ impl fmt::Display for A { //--- Building into DNS messages impl BuildIntoMessage for A { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.as_bytes().build_into_message(builder) } } @@ -132,10 +126,7 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { //--- Building into DNS messages impl BuildIntoMessage for Ns { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.name.build_into_message(builder) } } @@ -176,10 +167,7 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for CName { //--- Building into DNS messages impl BuildIntoMessage for CName { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.name.build_into_message(builder) } } @@ -254,7 +242,7 @@ impl BuildIntoMessage for Soa { fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { self.mname.build_into_message(builder.delegate())?; self.rname.build_into_message(builder.delegate())?; builder.append_bytes(self.serial.as_bytes())?; @@ -262,8 +250,7 @@ impl BuildIntoMessage for Soa { builder.append_bytes(self.retry.as_bytes())?; builder.append_bytes(self.expire.as_bytes())?; builder.append_bytes(self.minimum.as_bytes())?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } @@ -314,10 +301,7 @@ impl fmt::Debug for Wks { //--- Building into DNS messages impl BuildIntoMessage for Wks { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.as_bytes().build_into_message(builder) } } @@ -358,10 +342,7 @@ impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { //--- Building into DNS messages impl BuildIntoMessage for Ptr { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.name.build_into_message(builder) } } @@ -399,11 +380,10 @@ impl BuildIntoMessage for HInfo<'_> { fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { self.cpu.build_into_message(builder.delegate())?; self.os.build_into_message(builder.delegate())?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } @@ -454,11 +434,10 @@ impl BuildIntoMessage for Mx { fn build_into_message( &self, mut builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + ) -> BuildResult { builder.append_bytes(self.preference.as_bytes())?; self.exchange.build_into_message(builder.delegate())?; - builder.commit(); - Ok(()) + Ok(builder.commit()) } } @@ -508,10 +487,7 @@ impl<'a> ParseFromMessage<'a> for &'a Txt { //--- Building into DNS messages impl BuildIntoMessage for Txt { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.content.build_into_message(builder) } } diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 5c2ccce29..43327b50c 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -8,7 +8,7 @@ use domain_macros::*; use crate::{ new_base::{ - build::{self, BuildIntoMessage, TruncationError}, + build::{self, BuildIntoMessage, BuildResult}, wire::{ParseError, SplitBytes}, }, new_edns::EdnsOption, @@ -48,10 +48,7 @@ impl fmt::Debug for Opt { //--- Building into DNS messages impl BuildIntoMessage for Opt { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.contents.build_into_message(builder) } } diff --git a/src/new_rdata/ipv6.rs b/src/new_rdata/ipv6.rs index 788a1ca97..f91ae5e7b 100644 --- a/src/new_rdata/ipv6.rs +++ b/src/new_rdata/ipv6.rs @@ -11,8 +11,8 @@ use std::net::Ipv6Addr; use domain_macros::*; use crate::new_base::{ - build::{self, BuildIntoMessage}, - wire::{AsBytes, TruncationError}, + build::{self, BuildIntoMessage, BuildResult}, + wire::AsBytes, }; //----------- Aaaa ----------------------------------------------------------- @@ -81,10 +81,7 @@ impl fmt::Display for Aaaa { //--- Building into DNS messages impl BuildIntoMessage for Aaaa { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { self.as_bytes().build_into_message(builder) } } diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index ac1804ea8..e4b94a538 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -3,7 +3,7 @@ use domain_macros::*; use crate::new_base::{ - build::{self, BuildIntoMessage}, + build::{self, BuildIntoMessage, BuildResult}, parse::{ParseFromMessage, SplitFromMessage}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, Message, ParseRecordData, RType, @@ -132,10 +132,7 @@ where //--- Building record data impl BuildIntoMessage for RecordData<'_, N> { - fn build_into_message( - &self, - builder: build::Builder<'_>, - ) -> Result<(), TruncationError> { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { match self { Self::A(r) => r.build_into_message(builder), Self::Ns(r) => r.build_into_message(builder), From 6caee28ebc672d9e9e26705b2f98fc28af39b0f5 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 9 Jan 2025 13:03:49 +0100 Subject: [PATCH 080/167] [new_base/build] Define 'MessageBuilder' --- src/new_base/build/builder.rs | 16 +- src/new_base/build/message.rs | 243 +++++++++++++++++++++++++++++ src/new_base/build/mod.rs | 3 + src/new_base/record.rs | 13 +- src/new_base/wire/size_prefixed.rs | 61 +++++++- 5 files changed, 321 insertions(+), 15 deletions(-) create mode 100644 src/new_base/build/message.rs diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 8a0bbba33..274bbd7c6 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -133,7 +133,12 @@ impl<'b> Builder<'b> { } /// The appended but uncommitted contents of the message, mutably. - pub fn appended_mut(&mut self) -> &mut [u8] { + /// + /// # Safety + /// + /// The caller must not modify any compressed names among these bytes. + /// This can invalidate name compression state. + pub unsafe fn appended_mut(&mut self) -> &mut [u8] { // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. let range = self.commit..self.context.size; unsafe { &mut (*self.message.as_ptr()).contents[range] } @@ -175,6 +180,15 @@ impl<'b> Builder<'b> { .expect("'message' represents a valid 'Message'") } + /// A pointer to the message, including any uncommitted contents. + /// + /// The first `commit` bytes of the message contents (also provided by + /// [`Self::committed()`]) are immutably borrowed for the lifetime `'b`. + /// The remainder of the message is initialized and borrowed by `self`. + pub fn cur_message_ptr(&self) -> NonNull { + self.cur_message().into() + } + /// The builder context. pub fn context(&self) -> &BuilderContext { &*self.context diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs new file mode 100644 index 000000000..fd8fa3b45 --- /dev/null +++ b/src/new_base/build/message.rs @@ -0,0 +1,243 @@ +//! Building whole DNS messages. + +//----------- MessageBuilder ------------------------------------------------- + +use crate::new_base::{ + wire::TruncationError, Header, Message, Question, Record, +}; + +use super::{BuildIntoMessage, Builder, BuilderContext}; + +/// A builder for a whole DNS message. +/// +/// This is subtly different from a regular [`Builder`] -- it does not allow +/// for commits and so can always modify the entire message. It has methods +/// for adding entire questions and records to the message. +pub struct MessageBuilder<'b> { + /// The underlying [`Builder`]. + /// + /// Its commit point is always 0. + inner: Builder<'b>, +} + +//--- Initialization + +impl<'b> MessageBuilder<'b> { + /// Construct a [`MessageBuilder`] from raw parts. + /// + /// # Safety + /// + /// - `message` and `context` are paired together. + pub unsafe fn from_raw_parts( + message: &'b mut Message, + context: &'b mut BuilderContext, + ) -> Self { + // SAFETY: since 'commit' is 0, no part of the message is immutably + // borrowed; it is thus sound to represent as a mutable borrow. + let inner = + unsafe { Builder::from_raw_parts(message.into(), context, 0) }; + Self { inner } + } + + /// Initialize an empty [`MessageBuilder`]. + /// + /// The message header is left uninitialized. use [`Self::header_mut()`] + /// to initialize it. + /// + /// # Panics + /// + /// Panics if the buffer is less than 12 bytes long (which is the minimum + /// possible size for a DNS message). + pub fn new( + buffer: &'b mut [u8], + context: &'b mut BuilderContext, + ) -> Self { + let inner = Builder::new(buffer, context); + Self { inner } + } +} + +//--- Inspection + +impl<'b> MessageBuilder<'b> { + /// The message header. + /// + /// The header can be modified by the builder, and so is only available + /// for a short lifetime. Note that it implements [`Copy`]. + pub fn header(&self) -> &Header { + self.inner.header() + } + + /// Mutable access to the message header. + pub fn header_mut(&mut self) -> &mut Header { + self.inner.header_mut() + } + + /// Uninitialized space in the message buffer. + /// + /// This can be filled manually, then marked as initialized using + /// [`Self::mark_appended()`]. + pub fn uninitialized(&mut self) -> &mut [u8] { + self.inner.uninitialized() + } + + /// The message built thus far. + pub fn message(&self) -> &Message { + self.inner.cur_message() + } + + /// The message built thus far, mutably. + /// + /// # Safety + /// + /// The caller must not modify any compressed names among these bytes. + /// This can invalidate name compression state. + pub unsafe fn message_mut(&mut self) -> &mut Message { + // SAFETY: Since no bytes are committed, and the rest of the message + // is borrowed mutably for 'self', we can use a mutable reference. + unsafe { self.inner.cur_message_ptr().as_mut() } + } + + /// The builder context. + pub fn context(&self) -> &BuilderContext { + self.inner.context() + } + + /// Decompose this builder into raw parts. + /// + /// This returns the message buffer and the context for this builder. The + /// two are linked, and the builder can be recomposed with + /// [`Self::raw_from_parts()`]. + pub fn into_raw_parts(self) -> (&'b mut Message, &'b mut BuilderContext) { + let (mut message, context, _commit) = self.inner.into_raw_parts(); + // SAFETY: As per 'Builder::into_raw_parts()', the message is borrowed + // mutably for the lifetime 'b. Since the commit point is 0, there is + // no immutably-borrowed content in the message, so it can be turned + // into a regular reference. + (unsafe { message.as_mut() }, context) + } +} + +//--- Interaction + +impl MessageBuilder<'_> { + /// Mark bytes in the buffer as initialized. + /// + /// The given number of bytes from the beginning of + /// [`Self::uninitialized()`] will be marked as initialized, and will be + /// treated as appended content in the buffer. + /// + /// # Panics + /// + /// Panics if the uninitialized buffer is smaller than the given number of + /// initialized bytes. + pub fn mark_appended(&mut self, amount: usize) { + self.inner.mark_appended(amount) + } + + /// Limit the total message size. + /// + /// The message will not be allowed to exceed the given size, in bytes. + /// Only the message header and contents are counted; the enclosing UDP + /// or TCP packet size is not considered. If the message already exceeds + /// this size, a [`TruncationError`] is returned. + /// + /// This size will apply to all builders for this message (including those + /// that delegated to `self`). It will not be automatically revoked if + /// message building fails. + /// + /// # Panics + /// + /// Panics if the given size is less than 12 bytes. + pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { + self.inner.limit_to(size) + } + + /// Append a question. + /// + /// # Panics + /// + /// Panics if the message contains any records (as questions must come + /// before all records). + pub fn append_question( + &mut self, + question: &Question, + ) -> Result<(), TruncationError> + where + N: BuildIntoMessage, + { + // Ensure there are no records present. + let header = self.header(); + let records = header.counts.answers + + header.counts.authorities + + header.counts.additional; + assert_eq!(records, 0); + + question.build_into_message(self.inner.delegate())?; + + self.header_mut().counts.questions += 1; + Ok(()) + } + + /// Append an answer record. + /// + /// # Panics + /// + /// Panics if the message contains any authority or additional records. + pub fn append_answer( + &mut self, + record: &Record, + ) -> Result<(), TruncationError> + where + N: BuildIntoMessage, + D: BuildIntoMessage, + { + // Ensure there are no authority or additional records present. + let header = self.header(); + let records = header.counts.authorities + header.counts.additional; + assert_eq!(records, 0); + + record.build_into_message(self.inner.delegate())?; + + self.header_mut().counts.answers += 1; + Ok(()) + } + + /// Append an authority record. + /// + /// # Panics + /// + /// Panics if the message contains any additional records. + pub fn append_authority( + &mut self, + record: &Record, + ) -> Result<(), TruncationError> + where + N: BuildIntoMessage, + D: BuildIntoMessage, + { + // Ensure there are no additional records present. + let header = self.header(); + let records = header.counts.additional; + assert_eq!(records, 0); + + record.build_into_message(self.inner.delegate())?; + + self.header_mut().counts.authorities += 1; + Ok(()) + } + + /// Append an additional record. + pub fn append_additional( + &mut self, + record: &Record, + ) -> Result<(), TruncationError> + where + N: BuildIntoMessage, + D: BuildIntoMessage, + { + record.build_into_message(self.inner.delegate())?; + self.header_mut().counts.additional += 1; + Ok(()) + } +} diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 86e21ddfc..80b2b0942 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -3,6 +3,9 @@ mod builder; pub use builder::{Builder, BuilderContext}; +mod message; +pub use message::MessageBuilder; + use super::wire::TruncationError; //----------- Message-aware building traits ---------------------------------- diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 5e380ae0f..742a66977 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -118,17 +118,8 @@ where builder.append_bytes(self.rtype.as_bytes())?; builder.append_bytes(self.rclass.as_bytes())?; builder.append_bytes(self.ttl.as_bytes())?; - - // The offset of the record data size. - let offset = builder.appended().len(); - builder.append_bytes(&0u16.to_be_bytes())?; - self.rdata.build_into_message(builder.delegate())?; - let size = builder.appended().len() - 2 - offset; - let size = - u16::try_from(size).expect("the record data never exceeds 64KiB"); - builder.appended_mut()[offset..offset + 2] - .copy_from_slice(&size.to_be_bytes()); - + SizePrefixed::new(&self.rdata) + .build_into_message(builder.delegate())?; Ok(builder.commit()) } } diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index 5e4fc217e..751a57395 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -5,6 +5,12 @@ use core::{ ops::{Deref, DerefMut}, }; +use crate::new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + parse::{ParseFromMessage, SplitFromMessage}, + Message, +}; + use super::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, SplitBytes, SplitBytesByRef, TruncationError, U16, @@ -102,6 +108,37 @@ impl AsMut for SizePrefixed { } } +//--- Parsing from DNS messages + +impl<'b, T: ParseFromMessage<'b>> ParseFromMessage<'b> for SizePrefixed { + fn parse_from_message( + message: &'b Message, + start: usize, + ) -> Result { + let (&size, rest) = <&U16>::split_from_message(message, start)?; + if rest + size.get() as usize != message.contents.len() { + return Err(ParseError); + } + T::parse_from_message(message, rest).map(Self::new) + } +} + +impl<'b, T: ParseFromMessage<'b>> SplitFromMessage<'b> for SizePrefixed { + fn split_from_message( + message: &'b Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let (&size, rest) = <&U16>::split_from_message(message, start)?; + let (start, rest) = (rest, rest + size.get() as usize); + if rest > message.contents.len() { + return Err(ParseError); + } + let message = message.slice_to(rest); + let data = T::parse_from_message(message, start)?; + Ok((Self::new(data), rest)) + } +} + //--- Parsing from bytes impl<'b, T: ParseBytes<'b>> ParseBytes<'b> for SizePrefixed { @@ -110,8 +147,7 @@ impl<'b, T: ParseBytes<'b>> ParseBytes<'b> for SizePrefixed { if rest.len() != size.get() as usize { return Err(ParseError); } - let data = T::parse_bytes(bytes)?; - Ok(Self { size, data }) + T::parse_bytes(bytes).map(Self::new) } } @@ -123,7 +159,7 @@ impl<'b, T: ParseBytes<'b>> SplitBytes<'b> for SizePrefixed { } let (data, rest) = rest.split_at(size.get() as usize); let data = T::parse_bytes(data)?; - Ok((Self { size, data }, rest)) + Ok((Self::new(data), rest)) } } @@ -209,6 +245,25 @@ unsafe impl SplitBytesByRef for SizePrefixed { } } +//--- Building into DNS messages + +impl BuildIntoMessage for SizePrefixed { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> BuildResult { + assert_eq!(builder.appended(), &[] as &[u8]); + builder.append_bytes(&0u16.to_be_bytes())?; + self.data.build_into_message(builder.delegate())?; + let size = builder.appended().len() - 2; + let size = u16::try_from(size).expect("the data never exceeds 64KiB"); + // SAFETY: A 'U16' is being modified, not a domain name. + let size_buf = unsafe { &mut builder.appended_mut()[0..2] }; + size_buf.copy_from_slice(&size.to_be_bytes()); + Ok(builder.commit()) + } +} + //--- Building into byte strings impl BuildBytes for SizePrefixed { From 95d0fe898b8600b0ec10abd7ab24a77aa71be134 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 9 Jan 2025 17:54:13 +0100 Subject: [PATCH 081/167] [new_base/message] Add 'as_bytes_mut()' and 'slice_to_mut()' --- src/new_base/message.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/new_base/message.rs b/src/new_base/message.rs index ab1aa903c..734a11051 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -19,6 +19,27 @@ pub struct Message { pub contents: [u8], } +//--- Inspection + +impl Message { + /// Represent this as a mutable byte sequence. + /// + /// Given `&mut self`, it is already possible to individually modify the + /// message header and contents; since neither has invalid instances, it + /// is safe to represent the entire object as mutable bytes. + pub fn as_bytes_mut(&mut self) -> &mut [u8] { + // SAFETY: + // - 'Self' has no padding bytes and no interior mutability. + // - Its size in memory is exactly 'size_of_val(self)'. + unsafe { + core::slice::from_raw_parts_mut( + self as *mut Self as *mut u8, + core::mem::size_of_val(self), + ) + } + } +} + //--- Interaction impl Message { @@ -30,6 +51,15 @@ impl Message { Self::parse_bytes_by_ref(bytes) .expect("A 12-or-more byte string is a valid 'Message'") } + + /// Truncate the contents of this message to the given size, mutably. + /// + /// The returned value will have a `contents` field of the given size. + pub fn slice_to_mut(&mut self, size: usize) -> &mut Self { + let bytes = &mut self.as_bytes_mut()[..12 + size]; + Self::parse_bytes_by_mut(bytes) + .expect("A 12-or-more byte string is a valid 'Message'") + } } //----------- Header --------------------------------------------------------- From 3f14ccae8b875947e0d1ad6f5e7aa6243162bce4 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 13 Jan 2025 12:07:12 +0100 Subject: [PATCH 082/167] [new_base/name] Add 'UnparsedName' --- src/new_base/name/mod.rs | 3 + src/new_base/name/reversed.rs | 60 +++++++------ src/new_base/name/unparsed.rs | 162 ++++++++++++++++++++++++++++++++++ 3 files changed, 196 insertions(+), 29 deletions(-) create mode 100644 src/new_base/name/unparsed.rs diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs index 9270f4d5c..ca9f3e581 100644 --- a/src/new_base/name/mod.rs +++ b/src/new_base/name/mod.rs @@ -19,3 +19,6 @@ pub use label::{Label, LabelIter}; mod reversed; pub use reversed::{RevName, RevNameBuf}; + +mod unparsed; +pub use unparsed::UnparsedName; diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index aa58b24cf..d022352be 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -1,4 +1,4 @@ -//! Reversed DNS names. +//! Reversed domain names. use core::{ borrow::Borrow, @@ -312,6 +312,7 @@ impl<'a> ParseFromMessage<'a> for RevNameBuf { } // Keep going, from the referenced position. + let start = start.checked_sub(12).ok_or(ParseError)?; let bytes = contents.get(start..).ok_or(ParseError)?; (pointer, _) = parse_segment(bytes, &mut buffer)?; old_start = start; @@ -331,38 +332,39 @@ fn parse_segment<'a>( buffer: &mut RevNameBuf, ) -> Result<(Option, &'a [u8]), ParseError> { loop { - let (&length, rest) = bytes.split_first().ok_or(ParseError)?; - if length == 0 { - // Found the root, stop. - buffer.prepend(&[0u8]); - return Ok((None, rest)); - } else if length < 64 { - // This looks like a regular label. - - if rest.len() < length as usize { - // The input doesn't contain the whole label. - return Err(ParseError); - } else if buffer.offset < 2 + length { - // The output name would exceed 254 bytes (this isn't - // the root label, so it can't fill the 255th byte). - return Err(ParseError); + match bytes { + &[0, ref rest @ ..] => { + // Found the root, stop. + buffer.prepend(&[0u8]); + return Ok((None, rest)); } - let (label, rest) = bytes.split_at(1 + length as usize); - buffer.prepend(label); - bytes = rest; - } else if length >= 0xC0 { - // This looks like a compression pointer. + &[l, ..] if l < 64 => { + // This looks like a regular label. + + if bytes.len() < 1 + l as usize { + // The input doesn't contain the whole label. + return Err(ParseError); + } else if buffer.offset < 2 + l { + // The output name would exceed 254 bytes (this isn't + // the root label, so it can't fill the 255th byte). + return Err(ParseError); + } + + let (label, rest) = bytes.split_at(1 + l as usize); + buffer.prepend(label); + bytes = rest; + } - let (&extra, rest) = rest.split_first().ok_or(ParseError)?; - let pointer = u16::from_be_bytes([length, extra]); + &[hi, lo, ref rest @ ..] if hi >= 0xC0 => { + let pointer = u16::from_be_bytes([hi, lo]); - // NOTE: We don't verify the pointer here, that's left to - // the caller (since they have to actually use it). - return Ok((Some(pointer & 0x3FFF), rest)); - } else { - // This is an invalid or deprecated label type. - return Err(ParseError); + // NOTE: We don't verify the pointer here, that's left to + // the caller (since they have to actually use it). + return Ok((Some(pointer & 0x3FFF), rest)); + } + + _ => return Err(ParseError), } } } diff --git a/src/new_base/name/unparsed.rs b/src/new_base/name/unparsed.rs new file mode 100644 index 000000000..e437ddd83 --- /dev/null +++ b/src/new_base/name/unparsed.rs @@ -0,0 +1,162 @@ +//! Unparsed domain names. + +use domain_macros::*; + +use crate::new_base::{ + parse::{ParseFromMessage, SplitFromMessage}, + wire::ParseError, + Message, +}; + +//----------- UnparsedName --------------------------------------------------- + +/// An unparsed domain name in a DNS message. +/// +/// Within a DNS message, domain names are stored in conventional order (from +/// innermost to the root label), and may end with a compression pointer. An +/// [`UnparsedName`] represents this incomplete domain name, exactly as stored +/// in a message. +#[derive(AsBytes)] +#[repr(transparent)] +pub struct UnparsedName([u8]); + +//--- Constants + +impl UnparsedName { + /// The maximum size of an unparsed domain name. + /// + /// A domain name can be 255 bytes at most, but an unparsed domain name + /// could replace the last byte (representing the root label) with a + /// compression pointer to it. Since compression pointers are 2 bytes, + /// the total size becomes 256 bytes. + pub const MAX_SIZE: usize = 256; + + /// The root name. + pub const ROOT: &'static Self = { + // SAFETY: A root label is the shortest valid name. + unsafe { Self::from_bytes_unchecked(&[0u8]) } + }; +} + +//--- Construction + +impl UnparsedName { + /// Assume a byte string is a valid [`UnparsedName`]. + /// + /// # Safety + /// + /// The byte string must contain any number of encoded labels, ending with + /// a root label or a compression pointer, as long as the size of the + /// whole string is 256 bytes or less. + pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self { + // SAFETY: 'UnparsedName' is 'repr(transparent)' to '[u8]', so casting + // a '[u8]' into an 'UnparsedName' is sound. + core::mem::transmute(bytes) + } +} + +//--- Inspection + +impl UnparsedName { + /// The size of this name in the wire format. + #[allow(clippy::len_without_is_empty)] + pub const fn len(&self) -> usize { + self.0.len() + } + + /// Whether this is the root label. + pub const fn is_root(&self) -> bool { + self.0.len() == 1 + } + + /// A byte representation of the [`UnparsedName`]. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } +} + +//--- Parsing from DNS messages + +impl<'a> SplitFromMessage<'a> for &'a UnparsedName { + fn split_from_message( + message: &'a Message, + start: usize, + ) -> Result<(Self, usize), ParseError> { + let bytes = message.contents.get(start..).ok_or(ParseError)?; + let mut offset = 0; + let offset = loop { + match &bytes[offset..] { + // This is the root label. + &[0, ..] => break offset + 1, + + // This looks like a regular label. + &[l, ref rest @ ..] if (1..64).contains(&l) => { + let length = l as usize; + + if rest.len() < length || offset + 2 + length > 255 { + // The name is incomplete or too big. + return Err(ParseError); + } + + offset += 1 + length; + } + + // This is a compression pointer. + &[hi, lo, ..] if hi >= 0xC0 => { + let ptr = u16::from_be_bytes([hi, lo]); + if usize::from(ptr - 0xC000) >= start { + return Err(ParseError); + } + break offset + 2; + } + + _ => return Err(ParseError), + } + }; + + let bytes = &bytes[..offset]; + let rest = start + offset; + Ok((unsafe { UnparsedName::from_bytes_unchecked(bytes) }, rest)) + } +} + +impl<'a> ParseFromMessage<'a> for &'a UnparsedName { + fn parse_from_message( + message: &'a Message, + start: usize, + ) -> Result { + let bytes = message.contents.get(start..).ok_or(ParseError)?; + let mut offset = 0; + loop { + match &bytes[offset..] { + // This is the root label. + &[0] => break, + + // This looks like a regular label. + &[l, ref rest @ ..] if (1..64).contains(&l) => { + let length = l as usize; + + if rest.len() < length || offset + 2 + length > 255 { + // The name is incomplete or too big. + return Err(ParseError); + } + + offset += 1 + length; + } + + // This is a compression pointer. + &[hi, lo] if hi >= 0xC0 => { + let ptr = u16::from_be_bytes([hi, lo]); + if usize::from(ptr - 0xC000) >= start { + return Err(ParseError); + } + break; + } + + _ => return Err(ParseError), + } + } + + Ok(unsafe { UnparsedName::from_bytes_unchecked(bytes) }) + } +} From 2c9594efde7310026af45905a0adc2663e5102b5 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 13 Jan 2025 14:21:01 +0100 Subject: [PATCH 083/167] [new_base/message] Impl 'as_array()' for 'SectionCounts' --- src/new_base/message.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/new_base/message.rs b/src/new_base/message.rs index 734a11051..d900b4a1c 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -295,6 +295,22 @@ pub struct SectionCounts { pub additional: U16, } +//--- Interaction + +impl SectionCounts { + /// Represent these counts as an array. + pub fn as_array(&self) -> &[U16; 4] { + // SAFETY: 'SectionCounts' has the same layout as '[U16; 4]'. + unsafe { core::mem::transmute(self) } + } + + /// Represent these counts as a mutable array. + pub fn as_array_mut(&mut self) -> &mut [U16; 4] { + // SAFETY: 'SectionCounts' has the same layout as '[U16; 4]'. + unsafe { core::mem::transmute(self) } + } +} + //--- Formatting impl fmt::Display for SectionCounts { From a01a7f24092cb8fc625a39b73e2813115434c88a Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 13 Jan 2025 14:21:16 +0100 Subject: [PATCH 084/167] [new_base/build] Define 'RecordBuilder' --- src/new_base/build/message.rs | 114 ++++++++++++++--------- src/new_base/build/mod.rs | 3 + src/new_base/build/record.rs | 166 ++++++++++++++++++++++++++++++++++ 3 files changed, 242 insertions(+), 41 deletions(-) create mode 100644 src/new_base/build/record.rs diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index fd8fa3b45..a79928726 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -1,12 +1,13 @@ //! Building whole DNS messages. -//----------- MessageBuilder ------------------------------------------------- - use crate::new_base::{ - wire::TruncationError, Header, Message, Question, Record, + wire::TruncationError, Header, Message, Question, RClass, RType, Record, + TTL, }; -use super::{BuildIntoMessage, Builder, BuilderContext}; +use super::{BuildIntoMessage, Builder, BuilderContext, RecordBuilder}; + +//----------- MessageBuilder ------------------------------------------------- /// A builder for a whole DNS message. /// @@ -73,14 +74,6 @@ impl<'b> MessageBuilder<'b> { self.inner.header_mut() } - /// Uninitialized space in the message buffer. - /// - /// This can be filled manually, then marked as initialized using - /// [`Self::mark_appended()`]. - pub fn uninitialized(&mut self) -> &mut [u8] { - self.inner.uninitialized() - } - /// The message built thus far. pub fn message(&self) -> &Message { self.inner.cur_message() @@ -107,7 +100,7 @@ impl<'b> MessageBuilder<'b> { /// /// This returns the message buffer and the context for this builder. The /// two are linked, and the builder can be recomposed with - /// [`Self::raw_from_parts()`]. + /// [`Self::from_raw_parts()`]. pub fn into_raw_parts(self) -> (&'b mut Message, &'b mut BuilderContext) { let (mut message, context, _commit) = self.inner.into_raw_parts(); // SAFETY: As per 'Builder::into_raw_parts()', the message is borrowed @@ -121,20 +114,6 @@ impl<'b> MessageBuilder<'b> { //--- Interaction impl MessageBuilder<'_> { - /// Mark bytes in the buffer as initialized. - /// - /// The given number of bytes from the beginning of - /// [`Self::uninitialized()`] will be marked as initialized, and will be - /// treated as appended content in the buffer. - /// - /// # Panics - /// - /// Panics if the uninitialized buffer is smaller than the given number of - /// initialized bytes. - pub fn mark_appended(&mut self, amount: usize) { - self.inner.mark_appended(amount) - } - /// Limit the total message size. /// /// The message will not be allowed to exceed the given size, in bytes. @@ -167,18 +146,36 @@ impl MessageBuilder<'_> { N: BuildIntoMessage, { // Ensure there are no records present. - let header = self.header(); - let records = header.counts.answers - + header.counts.authorities - + header.counts.additional; - assert_eq!(records, 0); + assert_eq!(self.header().counts.as_array()[1..], [0, 0, 0]); question.build_into_message(self.inner.delegate())?; - self.header_mut().counts.questions += 1; Ok(()) } + /// Build an arbitrary record. + /// + /// The record will be added to the specified section (1, 2, or 3, i.e. + /// answers, authorities, and additional records respectively). There + /// must not be any existing records in sections after this one. + pub fn build_record( + &mut self, + rname: impl BuildIntoMessage, + rtype: RType, + rclass: RClass, + ttl: TTL, + section: u8, + ) -> Result, TruncationError> { + RecordBuilder::new( + self.inner.delegate(), + rname, + rtype, + rclass, + ttl, + section, + ) + } + /// Append an answer record. /// /// # Panics @@ -193,16 +190,28 @@ impl MessageBuilder<'_> { D: BuildIntoMessage, { // Ensure there are no authority or additional records present. - let header = self.header(); - let records = header.counts.authorities + header.counts.additional; - assert_eq!(records, 0); + assert_eq!(self.header().counts.as_array()[2..], [0, 0]); record.build_into_message(self.inner.delegate())?; - self.header_mut().counts.answers += 1; Ok(()) } + /// Build an answer record. + /// + /// # Panics + /// + /// Panics if the message contains any authority or additional records. + pub fn build_answer( + &mut self, + rname: impl BuildIntoMessage, + rtype: RType, + rclass: RClass, + ttl: TTL, + ) -> Result, TruncationError> { + self.build_record(rname, rtype, rclass, ttl, 1) + } + /// Append an authority record. /// /// # Panics @@ -217,16 +226,28 @@ impl MessageBuilder<'_> { D: BuildIntoMessage, { // Ensure there are no additional records present. - let header = self.header(); - let records = header.counts.additional; - assert_eq!(records, 0); + assert_eq!(self.header().counts.as_array()[3..], [0]); record.build_into_message(self.inner.delegate())?; - self.header_mut().counts.authorities += 1; Ok(()) } + /// Build an authority record. + /// + /// # Panics + /// + /// Panics if the message contains any additional records. + pub fn build_authority( + &mut self, + rname: impl BuildIntoMessage, + rtype: RType, + rclass: RClass, + ttl: TTL, + ) -> Result, TruncationError> { + self.build_record(rname, rtype, rclass, ttl, 2) + } + /// Append an additional record. pub fn append_additional( &mut self, @@ -240,4 +261,15 @@ impl MessageBuilder<'_> { self.header_mut().counts.additional += 1; Ok(()) } + + /// Build an additional record. + pub fn build_additional( + &mut self, + rname: impl BuildIntoMessage, + rtype: RType, + rclass: RClass, + ttl: TTL, + ) -> Result, TruncationError> { + self.build_record(rname, rtype, rclass, ttl, 3) + } } diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 80b2b0942..7b1598ede 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -6,6 +6,9 @@ pub use builder::{Builder, BuilderContext}; mod message; pub use message::MessageBuilder; +mod record; +pub use record::RecordBuilder; + use super::wire::TruncationError; //----------- Message-aware building traits ---------------------------------- diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs new file mode 100644 index 000000000..873628bda --- /dev/null +++ b/src/new_base/build/record.rs @@ -0,0 +1,166 @@ +//! Building DNS records. + +use crate::new_base::{ + name::RevName, + wire::{AsBytes, TruncationError}, + Header, Message, RClass, RType, TTL, +}; + +use super::{BuildCommitted, BuildIntoMessage, Builder}; + +//----------- RecordBuilder -------------------------------------------------- + +/// A builder for a DNS record. +/// +/// This is used to incrementally build the data for a DNS record. It can be +/// constructed using [`MessageBuilder::build_answer()`] etc. +/// +/// [`MessageBuilder::build_answer()`]: super::MessageBuilder::build_answer() +pub struct RecordBuilder<'b> { + /// The underlying [`Builder`]. + /// + /// Its commit point lies at the beginning of the record. + inner: Builder<'b>, + + /// The position of the record data. + /// + /// This is an offset from the message contents. + start: usize, + + /// The section the record is a part of. + /// + /// The appropriate section count will be incremented on completion. + section: u8, +} + +//--- Initialization + +impl<'b> RecordBuilder<'b> { + /// Construct a [`RecordBuilder`] from raw parts. + /// + /// # Safety + /// + /// - `builder`, `start`, and `section` are paired together. + pub unsafe fn from_raw_parts( + builder: Builder<'b>, + start: usize, + section: u8, + ) -> Self { + Self { + inner: builder, + start, + section, + } + } + + /// Initialize a new [`RecordBuilder`]. + /// + /// A new record with the given name, type, and class will be created. + /// The returned builder can be used to add data for the record. + /// + /// The count for the specified section (1, 2, or 3, i.e. answers, + /// authorities, and additional records respectively) will be incremented + /// when the builder finishes successfully. + pub fn new( + mut builder: Builder<'b>, + rname: impl BuildIntoMessage, + rtype: RType, + rclass: RClass, + ttl: TTL, + section: u8, + ) -> Result { + debug_assert_eq!(builder.appended(), &[] as &[u8]); + debug_assert!((1..4).contains(§ion)); + + assert!(builder + .header() + .counts + .as_array() + .iter() + .skip(1 + section as usize) + .all(|&c| c == 0)); + + // Build the record header. + rname.build_into_message(builder.delegate())?; + builder.append_bytes(rtype.as_bytes())?; + builder.append_bytes(rclass.as_bytes())?; + builder.append_bytes(ttl.as_bytes())?; + let start = builder.appended().len(); + + // Set up the builder. + Ok(Self { + inner: builder, + start, + section, + }) + } +} + +//--- Inspection + +impl<'b> RecordBuilder<'b> { + /// The message header. + pub fn header(&self) -> &Header { + self.inner.header() + } + + /// The message without this record. + pub fn message(&self) -> &Message { + self.inner.message() + } + + /// The record data appended thus far. + pub fn data(&self) -> &[u8] { + &self.inner.appended()[self.start..] + } + + /// Decompose this builder into raw parts. + /// + /// This returns the underlying builder, the offset of the record data in + /// the record, and the section number for this record (1, 2, or 3). The + /// builder can be recomposed with [`Self::from_raw_parts()`]. + pub fn into_raw_parts(self) -> (Builder<'b>, usize, u8) { + (self.inner, self.start, self.section) + } +} + +//--- Interaction + +impl RecordBuilder<'_> { + /// Finish the record. + /// + /// The respective section count will be incremented. The builder will be + /// consumed and the record will be committed. + pub fn finish(mut self) -> BuildCommitted { + // Increment the appropriate section count. + self.inner.header_mut().counts.as_array_mut() + [self.section as usize] += 1; + + self.inner.commit() + } + + /// Delegate to a new builder. + /// + /// Any content committed by the builder will be added as record data. + pub fn delegate(&mut self) -> Builder<'_> { + self.inner.delegate() + } + + /// Append some bytes. + /// + /// No name compression will be performed. + pub fn append_bytes( + &mut self, + bytes: &[u8], + ) -> Result<(), TruncationError> { + self.inner.append_bytes(bytes) + } + + /// Compress and append a domain name. + pub fn append_name( + &mut self, + name: &RevName, + ) -> Result<(), TruncationError> { + self.inner.append_name(name) + } +} From 7ef9b5b63c122caccb0d0fd20453cde976950fe9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 13 Jan 2025 14:24:23 +0100 Subject: [PATCH 085/167] [new_base/name] Accept Clippy simplifications --- src/new_base/name/reversed.rs | 8 ++++---- src/new_base/name/unparsed.rs | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index d022352be..f33451f3b 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -332,14 +332,14 @@ fn parse_segment<'a>( buffer: &mut RevNameBuf, ) -> Result<(Option, &'a [u8]), ParseError> { loop { - match bytes { - &[0, ref rest @ ..] => { + match *bytes { + [0, ref rest @ ..] => { // Found the root, stop. buffer.prepend(&[0u8]); return Ok((None, rest)); } - &[l, ..] if l < 64 => { + [l, ..] if l < 64 => { // This looks like a regular label. if bytes.len() < 1 + l as usize { @@ -356,7 +356,7 @@ fn parse_segment<'a>( bytes = rest; } - &[hi, lo, ref rest @ ..] if hi >= 0xC0 => { + [hi, lo, ref rest @ ..] if hi >= 0xC0 => { let pointer = u16::from_be_bytes([hi, lo]); // NOTE: We don't verify the pointer here, that's left to diff --git a/src/new_base/name/unparsed.rs b/src/new_base/name/unparsed.rs index e437ddd83..828c92229 100644 --- a/src/new_base/name/unparsed.rs +++ b/src/new_base/name/unparsed.rs @@ -85,12 +85,12 @@ impl<'a> SplitFromMessage<'a> for &'a UnparsedName { let bytes = message.contents.get(start..).ok_or(ParseError)?; let mut offset = 0; let offset = loop { - match &bytes[offset..] { + match bytes[offset..] { // This is the root label. - &[0, ..] => break offset + 1, + [0, ..] => break offset + 1, // This looks like a regular label. - &[l, ref rest @ ..] if (1..64).contains(&l) => { + [l, ref rest @ ..] if (1..64).contains(&l) => { let length = l as usize; if rest.len() < length || offset + 2 + length > 255 { @@ -102,7 +102,7 @@ impl<'a> SplitFromMessage<'a> for &'a UnparsedName { } // This is a compression pointer. - &[hi, lo, ..] if hi >= 0xC0 => { + [hi, lo, ..] if hi >= 0xC0 => { let ptr = u16::from_be_bytes([hi, lo]); if usize::from(ptr - 0xC000) >= start { return Err(ParseError); @@ -128,12 +128,12 @@ impl<'a> ParseFromMessage<'a> for &'a UnparsedName { let bytes = message.contents.get(start..).ok_or(ParseError)?; let mut offset = 0; loop { - match &bytes[offset..] { + match bytes[offset..] { // This is the root label. - &[0] => break, + [0] => break, // This looks like a regular label. - &[l, ref rest @ ..] if (1..64).contains(&l) => { + [l, ref rest @ ..] if (1..64).contains(&l) => { let length = l as usize; if rest.len() < length || offset + 2 + length > 255 { @@ -145,7 +145,7 @@ impl<'a> ParseFromMessage<'a> for &'a UnparsedName { } // This is a compression pointer. - &[hi, lo] if hi >= 0xC0 => { + [hi, lo] if hi >= 0xC0 => { let ptr = u16::from_be_bytes([hi, lo]); if usize::from(ptr - 0xC000) >= start { return Err(ParseError); From bfa8c5cb37494043497a8f59b3e2b03304360b48 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 13 Jan 2025 17:00:47 +0100 Subject: [PATCH 086/167] [new_base/build/record] Track record data size --- src/new_base/build/record.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index 873628bda..aac0857c3 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -85,6 +85,7 @@ impl<'b> RecordBuilder<'b> { builder.append_bytes(rtype.as_bytes())?; builder.append_bytes(rclass.as_bytes())?; builder.append_bytes(ttl.as_bytes())?; + builder.append_bytes(&0u16.to_be_bytes())?; let start = builder.appended().len(); // Set up the builder. @@ -136,6 +137,15 @@ impl RecordBuilder<'_> { self.inner.header_mut().counts.as_array_mut() [self.section as usize] += 1; + // Set the record data length. + let size = self.inner.appended().len() - self.start; + let size = u16::try_from(size) + .expect("Record data must be smaller than 64KiB"); + // SAFETY: The record data size is not part of a compressed name. + let appended = unsafe { self.inner.appended_mut() }; + appended[self.start - 2..self.start] + .copy_from_slice(&size.to_be_bytes()); + self.inner.commit() } From 07cf2deed00d5811b0c1b964d73da0d49727b07b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 16 Jan 2025 12:38:44 +0100 Subject: [PATCH 087/167] [new_base/parse] Remove 'ParseMessage' etc. These interfaces need to be redesigned to be more specific to important use cases. --- src/new_base/parse/message.rs | 49 ----------- src/new_base/parse/mod.rs | 9 -- src/new_base/parse/question.rs | 148 --------------------------------- src/new_base/parse/record.rs | 148 --------------------------------- 4 files changed, 354 deletions(-) delete mode 100644 src/new_base/parse/message.rs delete mode 100644 src/new_base/parse/question.rs delete mode 100644 src/new_base/parse/record.rs diff --git a/src/new_base/parse/message.rs b/src/new_base/parse/message.rs deleted file mode 100644 index 1c964588a..000000000 --- a/src/new_base/parse/message.rs +++ /dev/null @@ -1,49 +0,0 @@ -//! Parsing DNS messages. - -use core::ops::ControlFlow; - -use crate::new_base::{Header, UnparsedQuestion, UnparsedRecord}; - -/// A type that can be constructed by parsing a DNS message. -pub trait ParseMessage<'a>: Sized { - /// The type of visitors for incrementally building the output. - type Visitor: VisitMessagePart<'a>; - - /// The type of errors from converting a visitor into [`Self`]. - // TODO: Just use 'Visitor::Error'? - type Error; - - /// Construct a visitor, providing the message header. - fn make_visitor(header: &'a Header) - -> Result; - - /// Convert a visitor back to this type. - fn from_visitor(visitor: Self::Visitor) -> Result; -} - -/// A type that can visit the components of a DNS message. -pub trait VisitMessagePart<'a> { - /// The type of errors produced by visits. - type Error; - - /// Visit a component of the message. - fn visit( - &mut self, - component: MessagePart<'a>, - ) -> Result, Self::Error>; -} - -/// A component of a DNS message. -pub enum MessagePart<'a> { - /// A question. - Question(&'a UnparsedQuestion), - - /// An answer record. - Answer(&'a UnparsedRecord<'a>), - - /// An authority record. - Authority(&'a UnparsedRecord<'a>), - - /// An additional record. - Additional(&'a UnparsedRecord<'a>), -} diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 7e5a08d7f..e6d47f4f0 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,14 +1,5 @@ //! Parsing DNS messages from the wire format. -mod message; -pub use message::{MessagePart, ParseMessage, VisitMessagePart}; - -mod question; -pub use question::{ParseQuestion, ParseQuestions, VisitQuestion}; - -mod record; -pub use record::{ParseRecord, ParseRecords, VisitRecord}; - pub use super::wire::ParseError; use super::{ diff --git a/src/new_base/parse/question.rs b/src/new_base/parse/question.rs deleted file mode 100644 index 784cadc09..000000000 --- a/src/new_base/parse/question.rs +++ /dev/null @@ -1,148 +0,0 @@ -//! Parsing DNS questions. - -use core::{convert::Infallible, ops::ControlFlow}; - -#[cfg(feature = "std")] -use std::boxed::Box; -#[cfg(feature = "std")] -use std::vec::Vec; - -use crate::new_base::UnparsedQuestion; - -//----------- Trait definitions ---------------------------------------------- - -/// A type that can be constructed by parsing exactly one DNS question. -pub trait ParseQuestion: Sized { - /// The type of parse errors. - // TODO: Remove entirely? - type Error; - - /// Parse the given DNS question. - fn parse_question( - question: &UnparsedQuestion, - ) -> Result, Self::Error>; -} - -/// A type that can be constructed by parsing zero or more DNS questions. -pub trait ParseQuestions: Sized { - /// The type of visitors for incrementally building the output. - type Visitor: Default + VisitQuestion; - - /// The type of errors from converting a visitor into [`Self`]. - // TODO: Just use 'Visitor::Error'? Or remove entirely? - type Error; - - /// Convert a visitor back to this type. - fn from_visitor(visitor: Self::Visitor) -> Result; -} - -/// A type that can visit DNS questions. -pub trait VisitQuestion { - /// The type of errors produced by visits. - type Error; - - /// Visit a question. - fn visit_question( - &mut self, - question: &UnparsedQuestion, - ) -> Result, Self::Error>; -} - -//----------- Trait implementations ------------------------------------------ - -impl ParseQuestion for UnparsedQuestion { - type Error = Infallible; - - fn parse_question( - question: &UnparsedQuestion, - ) -> Result, Self::Error> { - Ok(ControlFlow::Break(question.clone())) - } -} - -//--- Impls for 'Option' - -impl ParseQuestion for Option { - type Error = T::Error; - - fn parse_question( - question: &UnparsedQuestion, - ) -> Result, Self::Error> { - Ok(match T::parse_question(question)? { - ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -impl ParseQuestions for Option { - type Visitor = Option; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor) - } -} - -impl VisitQuestion for Option { - type Error = T::Error; - - fn visit_question( - &mut self, - question: &UnparsedQuestion, - ) -> Result, Self::Error> { - if self.is_some() { - return Ok(ControlFlow::Continue(())); - } - - Ok(match T::parse_question(question)? { - ControlFlow::Break(elem) => { - *self = Some(elem); - ControlFlow::Break(()) - } - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -//--- Impls for 'Vec' - -#[cfg(feature = "std")] -impl ParseQuestions for Vec { - type Visitor = Vec; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor) - } -} - -#[cfg(feature = "std")] -impl VisitQuestion for Vec { - type Error = T::Error; - - fn visit_question( - &mut self, - question: &UnparsedQuestion, - ) -> Result, Self::Error> { - Ok(match T::parse_question(question)? { - ControlFlow::Break(elem) => { - self.push(elem); - ControlFlow::Break(()) - } - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -//--- Impls for 'Box<[T]>' - -#[cfg(feature = "std")] -impl ParseQuestions for Box<[T]> { - type Visitor = Vec; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor.into_boxed_slice()) - } -} diff --git a/src/new_base/parse/record.rs b/src/new_base/parse/record.rs deleted file mode 100644 index 75e98a36a..000000000 --- a/src/new_base/parse/record.rs +++ /dev/null @@ -1,148 +0,0 @@ -//! Parsing DNS records. - -use core::{convert::Infallible, ops::ControlFlow}; - -#[cfg(feature = "std")] -use std::boxed::Box; -#[cfg(feature = "std")] -use std::vec::Vec; - -use crate::new_base::UnparsedRecord; - -//----------- Trait definitions ---------------------------------------------- - -/// A type that can be constructed by parsing exactly one DNS record. -pub trait ParseRecord<'a>: Sized { - /// The type of parse errors. - // TODO: Remove entirely? - type Error; - - /// Parse the given DNS record. - fn parse_record( - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error>; -} - -/// A type that can be constructed by parsing zero or more DNS records. -pub trait ParseRecords<'a>: Sized { - /// The type of visitors for incrementally building the output. - type Visitor: Default + VisitRecord<'a>; - - /// The type of errors from converting a visitor into [`Self`]. - // TODO: Just use 'Visitor::Error'? Or remove entirely? - type Error; - - /// Convert a visitor back to this type. - fn from_visitor(visitor: Self::Visitor) -> Result; -} - -/// A type that can visit DNS records. -pub trait VisitRecord<'a> { - /// The type of errors produced by visits. - type Error; - - /// Visit a record. - fn visit_record( - &mut self, - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error>; -} - -//----------- Trait implementations ------------------------------------------ - -impl<'a> ParseRecord<'a> for UnparsedRecord<'a> { - type Error = Infallible; - - fn parse_record( - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error> { - Ok(ControlFlow::Break(record.clone())) - } -} - -//--- Impls for 'Option' - -impl<'a, T: ParseRecord<'a>> ParseRecord<'a> for Option { - type Error = T::Error; - - fn parse_record( - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error> { - Ok(match T::parse_record(record)? { - ControlFlow::Break(elem) => ControlFlow::Break(Some(elem)), - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Option { - type Visitor = Option; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor) - } -} - -impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Option { - type Error = T::Error; - - fn visit_record( - &mut self, - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error> { - if self.is_some() { - return Ok(ControlFlow::Continue(())); - } - - Ok(match T::parse_record(record)? { - ControlFlow::Break(elem) => { - *self = Some(elem); - ControlFlow::Break(()) - } - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -//--- Impls for 'Vec' - -#[cfg(feature = "std")] -impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Vec { - type Visitor = Vec; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor) - } -} - -#[cfg(feature = "std")] -impl<'a, T: ParseRecord<'a>> VisitRecord<'a> for Vec { - type Error = T::Error; - - fn visit_record( - &mut self, - record: &UnparsedRecord<'a>, - ) -> Result, Self::Error> { - Ok(match T::parse_record(record)? { - ControlFlow::Break(elem) => { - self.push(elem); - ControlFlow::Break(()) - } - ControlFlow::Continue(()) => ControlFlow::Continue(()), - }) - } -} - -//--- Impls for 'Box<[T]>' - -#[cfg(feature = "std")] -impl<'a, T: ParseRecord<'a>> ParseRecords<'a> for Box<[T]> { - type Visitor = Vec; - type Error = Infallible; - - fn from_visitor(visitor: Self::Visitor) -> Result { - Ok(visitor.into_boxed_slice()) - } -} From 477252047ef16a64c25c1f5426d02db863eae53f Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 16 Jan 2025 12:47:08 +0100 Subject: [PATCH 088/167] [new_base/build] Document 'BuildCommitted' thoroughly --- src/new_base/build/mod.rs | 54 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 7b1598ede..38ab501fe 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -38,13 +38,63 @@ impl BuildIntoMessage for [u8] { //----------- BuildResult ---------------------------------------------------- /// The result of building into a DNS message. +/// +/// This is used in [`BuildIntoMessage::build_into_message()`]. pub type BuildResult = Result; //----------- BuildCommitted ------------------------------------------------- /// The output of [`Builder::commit()`]. /// -/// This is a stub type to remind users to call [`Builder::commit()`] in all -/// success paths of building functions. +/// This is a simple marker type, produced by [`Builder::commit()`]. Certain +/// trait methods (e.g. [`BuildIntoMessage::build_into_message()`]) require it +/// in the return type, as a way to remind users to commit their builders. +/// +/// # Examples +/// +/// If `build_into_message()` simply returned a unit type, an example impl may +/// look like: +/// +/// ```compile_fail +/// # use domain::new_base::name::RevName; +/// # use domain::new_base::build::{BuildIntoMessage, Builder, BuildResult}; +/// # use domain::new_base::wire::AsBytes; +/// +/// struct Foo<'a>(&'a RevName, u8); +/// +/// impl BuildIntoMessage for Foo<'_> { +/// fn build_into_message( +/// &self, +/// mut builder: Builder<'_>, +/// ) -> BuildResult { +/// builder.append_name(self.0)?; +/// builder.append_bytes(self.1.as_bytes()); +/// Ok(()) +/// } +/// } +/// ``` +/// +/// This code is incorrect: since the appended content is not committed, the +/// builder will remove it when it is dropped (at the end of the function), +/// and so nothing gets written. Instead, users have to write: +/// +/// ``` +/// # use domain::new_base::name::RevName; +/// # use domain::new_base::build::{BuildIntoMessage, Builder, BuildResult}; +/// # use domain::new_base::wire::AsBytes; +/// +/// struct Foo<'a>(&'a RevName, u8); +/// +/// impl BuildIntoMessage for Foo<'_> { +/// fn build_into_message( +/// &self, +/// mut builder: Builder<'_>, +/// ) -> BuildResult { +/// builder.append_name(self.0)?; +/// builder.append_bytes(self.1.as_bytes()); +/// Ok(builder.commit()) +/// } +/// } +/// ``` #[derive(Debug)] pub struct BuildCommitted; From 554bb71462da71d1098b92659d84314e090dd8ca Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 20 Jan 2025 14:55:10 +0100 Subject: [PATCH 089/167] [new_base/build/builder] Rewrite with a lot of documentation --- src/new_base/build/builder.rs | 430 ++++++++++++++++++++++++---------- src/new_base/message.rs | 18 ++ 2 files changed, 329 insertions(+), 119 deletions(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 274bbd7c6..16444349c 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -8,7 +8,7 @@ use core::{ use crate::new_base::{ name::RevName, - wire::{AsBytes, BuildBytes, ParseBytesByRef, TruncationError}, + wire::{BuildBytes, ParseBytesByRef, TruncationError}, Header, Message, }; @@ -16,7 +16,86 @@ use super::BuildCommitted; //----------- Builder -------------------------------------------------------- -/// A DNS message builder. +/// A DNS wire format serializer. +/// +/// This can be used to write arbitrary bytes and (compressed) domain names to +/// a buffer containing a DNS message. It is a low-level interface, providing +/// the foundations for high-level builder types. +/// +/// In order to build a regular DNS message, users would typically look to +/// [`MessageBuilder`](super::MessageBuilder). This offers the high-level +/// interface (with methods to append questions and records) that most users +/// need. +/// +/// # Committing and Delegation +/// +/// [`Builder`] provides an "atomic" interface: if a function fails while +/// building a DNS message using a [`Builder`], any partial content added by +/// the [`Builder`] will be reverted. The content of a [`Builder`] is only +/// confirmed when [`Builder::commit()`] is called. +/// +/// It is useful to first describe what "building functions" look like. While +/// they may take additional arguments, their signatures are usually: +/// +/// ```no_run +/// # use domain::new_base::build::{Builder, BuildResult}; +/// +/// fn foo(mut builder: Builder<'_>) -> BuildResult { +/// // Append to the message using 'builder'. +/// +/// // Commit all appended content and return successfully. +/// Ok(builder.commit()) +/// } +/// ``` +/// +/// Note that the builder is taken by value; if an error occurs, and the +/// function returns early, `builder` will be dropped, and its drop code will +/// revert all uncommitted changes. However, if building is successful, the +/// appended content is committed, and so will not be reverted. +/// +/// If `foo` were to call another function with the same signature, it would +/// need to create a new [`Builder`] to pass in by value. This [`Builder`] +/// should refer to the same message buffer, but should have not report any +/// uncommitted content (so that only the content added by the called function +/// will be reverted on failure). For this, we have [`delegate()`]. +/// +/// [`delegate()`]: Self::delegate() +/// +/// For example: +/// +/// ``` +/// # use domain::new_base::build::{Builder, BuildResult, BuilderContext}; +/// +/// /// A build function with the conventional type signature. +/// fn foo(mut builder: Builder<'_>) -> BuildResult { +/// // Content added by the parent builder is considered committed. +/// assert_eq!(builder.committed(), b"hi! "); +/// +/// // Append some content to the builder. +/// builder.append_bytes(b"foo!")?; +/// +/// // Try appending a very long string, which can't fit. +/// builder.append_bytes(b"helloworldthisiswaytoobig")?; +/// +/// Ok(builder.commit()) +/// } +/// +/// // Construct a builder for a particular buffer. +/// let mut buffer = [0u8; 20]; +/// let mut context = BuilderContext::default(); +/// let mut builder = Builder::new(&mut buffer, &mut context); +/// +/// // Try appending some content to the builder. +/// builder.append_bytes(b"hi! ").unwrap(); +/// assert_eq!(builder.appended(), b"hi! "); +/// +/// // Try calling 'foo' -- note that it will fail. +/// // Note that we delegated the builder. +/// foo(builder.delegate()).unwrap_err(); +/// +/// // No partial content was written. +/// assert_eq!(builder.appended(), b"hi! "); +/// ``` pub struct Builder<'b> { /// The message being built. /// @@ -41,22 +120,74 @@ pub struct Builder<'b> { commit: usize, } -//--- Initialization - +/// # Initialization +/// +/// In order to begin building a DNS message: +/// +/// ``` +/// # use domain::new_base::build::{Builder, BuilderContext}; +/// +/// // Allocate a slice of 'u8's somewhere. +/// let mut buffer = [0u8; 20]; +/// +/// // Obtain a builder context. +/// // +/// // The value doesn't matter, it will be overwritten. +/// let mut context = BuilderContext::default(); +/// +/// // Construct the actual 'Builder'. +/// let builder = Builder::new(&mut buffer, &mut context); +/// +/// assert!(builder.committed().is_empty()); +/// assert!(builder.appended().is_empty()); +/// ``` impl<'b> Builder<'b> { + /// Create a [`Builder`] for a new, empty DNS message. + /// + /// The message header is left uninitialized. Use [`Self::header_mut()`] + /// to initialize it. The message contents are completely empty. + /// + /// The provided builder context will be overwritten with a default state. + /// + /// # Panics + /// + /// Panics if the buffer is less than 12 bytes long (which is the minimum + /// possible size for a DNS message). + pub fn new( + buffer: &'b mut [u8], + context: &'b mut BuilderContext, + ) -> Self { + let message = Message::parse_bytes_by_mut(buffer) + .expect("The buffure must be at least 12 bytes in size"); + context.size = 0; + + // SAFETY: 'message' and 'context' are now consistent. + unsafe { Self::from_raw_parts(message.into(), context, 0) } + } + /// Construct a [`Builder`] from raw parts. /// + /// The provided components must originate from [`into_raw_parts()`], and + /// none of the components can be modified since they were extracted. + /// + /// [`into_raw_parts()`]: Self::into_raw_parts() + /// + /// This method is useful when overcoming limitations in lifetimes or + /// borrow checking, or when a builder has to be constructed from another + /// with specific characteristics. + /// /// # Safety /// + /// The expression `from_raw_parts(message, context, commit)` is sound if + /// and only if all of the following conditions are satisfied: + /// /// - `message` is a valid reference for the lifetime `'b`. /// - `message.header` is mutably borrowed for `'b`. /// - `message.contents[..commit]` is immutably borrowed for `'b`. /// - `message.contents[commit..]` is mutably borrowed for `'b`. /// - /// - `message` and `context` are paired together. - /// - /// - `commit` is at most `context.size()`, which is at most - /// `context.max_size()`. + /// - `message` and `context` originate from the same builder. + /// - `commit <= context.size() <= message.contents.len()`. pub unsafe fn from_raw_parts( message: NonNull, context: &'b mut BuilderContext, @@ -69,53 +200,82 @@ impl<'b> Builder<'b> { commit, } } - - /// Initialize an empty [`Builder`]. - /// - /// The message header is left uninitialized. Use [`Self::header_mut()`] - /// to initialize it. - /// - /// # Panics - /// - /// Panics if the buffer is less than 12 bytes long (which is the minimum - /// possible size for a DNS message). - pub fn new( - buffer: &'b mut [u8], - context: &'b mut BuilderContext, - ) -> Self { - assert!(buffer.len() >= 12); - let message = Message::parse_bytes_by_mut(buffer) - .expect("A 'Message' can fit in 12 bytes"); - context.size = 0; - context.max_size = message.contents.len(); - - // SAFETY: 'message' and 'context' are now consistent. - unsafe { Self::from_raw_parts(message.into(), context, 0) } - } } -//--- Inspection - +/// # Inspection +/// +/// A [`Builder`] references a message buffer to write into. That buffer is +/// broken down into the following segments: +/// +/// ```text +/// name | position +/// --------------+--------- +/// header | +/// committed | 0 .. commit +/// appended | commit .. size +/// uninitialized | size .. limit +/// inaccessible | limit .. +/// ``` +/// +/// The DNS message header can be modified at any time. It is made available +/// through [`header()`] and [`header_mut()`]. In general, it is inadvisable +/// to change the section counts arbitrarily (although it will not cause +/// undefined behaviour). +/// +/// [`header()`]: Self::header() +/// [`header_mut()`]: Self::header_mut() +/// +/// The committed content of the builder is immutable, and is available to +/// reference, through [`committed()`], for the lifetime `'b`. +/// +/// [`committed()`]: Self::committed() +/// +/// The appended content of the builder is made available via [`appended()`]. +/// It is content that has been added by this builder, but that has not yet +/// been committed. When the [`Builder`] is dropped, this content is removed +/// (it becomes uninitialized). Appended content can be modified, but any +/// compressed names within it have to be handled with great care; they can +/// only be modified by removing them entirely (by rewinding the builder, +/// using [`rewind()`]) and building them again. When compressed names are +/// guaranteed to not be modified, [`appended_mut()`] can be used. +/// +/// [`appended()`]: Self::appended() +/// [`rewind()`]: Self::rewind() +/// [`appended_mut()`]: Self::appended_mut() +/// +/// The uninitialized space in the builder will be written to when appending +/// new content. It can be accessed directly, in case that is more efficient +/// for building, using [`uninitialized()`]. [`mark_appended()`] can be used +/// to specify how many bytes were initialized. +/// +/// [`uninitialized()`]: Self::uninitialized() +/// [`mark_appended()`]: Self::mark_appended() +/// +/// The inaccessible space of a builder cannot be written to. While it exists +/// in the underlying message buffer, it has been made inaccessible so that +/// the built message fits within certain size constraints. A message's size +/// can be limited using [`limit_to()`], but this only applies to the current +/// builder (and its delegates); parent builders are unaffected by it. +/// +/// [`limit_to()`]: Self::limit_to() impl<'b> Builder<'b> { - /// The message header. - /// - /// The header can be modified by the builder, and so is only available - /// for a short lifetime. Note that it implements [`Copy`]. + /// The header of the DNS message. pub fn header(&self) -> &Header { // SAFETY: 'message.header' is mutably borrowed by 'self'. unsafe { &(*self.message.as_ptr()).header } } - /// Mutable access to the message header. + /// The header of the DNS message, mutably. + /// + /// It is possible to modify the section counts arbitrarily through this + /// method; while doing so cannot cause undefined behaviour, it is not + /// recommended. pub fn header_mut(&mut self) -> &mut Header { // SAFETY: 'message.header' is mutably borrowed by 'self'. unsafe { &mut (*self.message.as_ptr()).header } } /// Committed message contents. - /// - /// The message contents are available for the lifetime `'b`; the builder - /// cannot be used to modify them since they have been committed. pub fn committed(&self) -> &'b [u8] { // SAFETY: 'message.contents[..commit]' is immutably borrowed by // 'self'. @@ -146,38 +306,31 @@ impl<'b> Builder<'b> { /// Uninitialized space in the message buffer. /// - /// This can be filled manually, then marked as initialized using - /// [`Self::mark_appended()`]. + /// When the first `n` bytes of the returned buffer are initialized, and + /// should be treated as appended content in the message, call + /// [`self.mark_appended(n)`](Self::mark_appended()). pub fn uninitialized(&mut self) -> &mut [u8] { // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. - let range = self.context.size..self.context.max_size; - unsafe { &mut (*self.message.as_ptr()).contents[range] } + unsafe { &mut (*self.message.as_ptr()).contents[self.context.size..] } } /// The message with all committed contents. /// /// The header of the message can be modified by the builder, so the /// returned reference has a short lifetime. The message contents can be - /// borrowed for a longer lifetime -- see [`Self::committed()`]. + /// borrowed for a longer lifetime -- see [`committed()`]. The message + /// does not include content that has been appended but not committed. + /// + /// [`committed()`]: Self::committed() pub fn message(&self) -> &Message { // SAFETY: All of 'message' can be immutably borrowed by 'self'. - let message = unsafe { &*self.message.as_ptr() }; - let message = &message.as_bytes()[..12 + self.commit]; - Message::parse_bytes_by_ref(message) - .expect("'message' represents a valid 'Message'") + unsafe { self.message.as_ref() }.slice_to(self.commit) } /// The message including any uncommitted contents. - /// - /// The header of the message can be modified by the builder, so the - /// returned reference has a short lifetime. The message contents can be - /// borrowed for a longer lifetime -- see [`Self::committed()`]. pub fn cur_message(&self) -> &Message { // SAFETY: All of 'message' can be immutably borrowed by 'self'. - let message = unsafe { &*self.message.as_ptr() }; - let message = &message.as_bytes()[..12 + self.context.size]; - Message::parse_bytes_by_ref(message) - .expect("'message' represents a valid 'Message'") + unsafe { self.message.as_ref() }.slice_to(self.context.size) } /// A pointer to the message, including any uncommitted contents. @@ -185,8 +338,11 @@ impl<'b> Builder<'b> { /// The first `commit` bytes of the message contents (also provided by /// [`Self::committed()`]) are immutably borrowed for the lifetime `'b`. /// The remainder of the message is initialized and borrowed by `self`. - pub fn cur_message_ptr(&self) -> NonNull { - self.cur_message().into() + pub fn cur_message_ptr(&mut self) -> NonNull { + let message = self.message.as_ptr(); + let size = self.context.size; + let message = unsafe { Message::ptr_slice_to(message, size) }; + unsafe { NonNull::new_unchecked(message) } } /// The builder context. @@ -194,6 +350,29 @@ impl<'b> Builder<'b> { &*self.context } + /// The start point of this builder. + /// + /// This is the offset into the message contents at which this builder was + /// initialized. The content before this point has been committed and is + /// immutable. The builder can be rewound up to this point. + pub fn start(&self) -> usize { + self.commit + } + + /// The size limit of this builder. + /// + /// This is the maximum size the message contents can grow to; beyond it, + /// [`TruncationError`]s will occur. The limit can be tightened using + /// [`limit_to()`](Self::limit_to()). + pub fn max_size(&self) -> usize { + // SAFETY: 'Message' ends with a slice DST, and so references to it + // hold the length of that slice; we can cast it to another slice type + // and the pointer representation is unchanged. By using a slice type + // of ZST elements, aliasing is impossible, and it can be dereferenced + // safely. + unsafe { &*(self.message.as_ptr() as *mut [()]) }.len() + } + /// Decompose this builder into raw parts. /// /// This returns three components: @@ -221,24 +400,74 @@ impl<'b> Builder<'b> { } } -//--- Interaction - +/// # Interaction +/// +/// There are several ways to build up a DNS message using a [`Builder`]. +/// +/// When directly adding content, use [`append_bytes()`] or [`append_name()`]. +/// The former will add the bytes as-is, while the latter will compress domain +/// names. +/// +/// [`append_bytes()`]: Self::append_bytes() +/// [`append_name()`]: Self::append_name() +/// +/// When delegating to another builder method, use [`delegate()`]. This will +/// construct a new [`Builder`] that borrows from the current one. When the +/// method returns, the content it has committed will be registered as content +/// appended (but not committed) by the outer builder. If the method fails, +/// any content it tried to add will be removed automatically, and the outer +/// builder will be left unaffected. +/// +/// [`delegate()`]: Self::delegate() +/// +/// After all data is appended, call [`commit()`]. This will return a marker +/// type, [`BuildCommitted`], that may need to be returned to the caller. +/// +/// [`commit()`]: Self::commit() +/// +/// Some lower-level building methods are also available in the interest of +/// efficiency. Use [`append_with()`] if the amount of data to be written is +/// known upfront; it takes a closure to fill that space in the buffer. The +/// most general and efficient technique is to write into [`uninitialized()`] +/// and to mark the number of initialized bytes using [`mark_appended()`]. +/// +/// [`append_with()`]: Self::append_with() +/// [`uninitialized()`]: Self::uninitialized() +/// [`mark_appended()`]: Self::mark_appended() impl Builder<'_> { - /// Rewind the builder, removing all committed content. + /// Rewind the builder, removing all uncommitted content. pub fn rewind(&mut self) { self.context.size = self.commit; } - /// Commit all appended content. + /// Commit the changes made by this builder. /// /// For convenience, a unit type [`BuildCommitted`] is returned; it is /// used as the return type of build functions to remind users to call /// this method on success paths. - pub fn commit(&mut self) -> BuildCommitted { + pub fn commit(mut self) -> BuildCommitted { + // Update 'commit' so that the drop glue is a no-op. self.commit = self.context.size; BuildCommitted } + /// Limit this builder to the given size. + /// + /// This builder, and all its delegates, will not allow the message + /// contents (i.e. excluding the 12-byte message header) to exceed the + /// specified size in bytes. If the message has already crossed that + /// limit, a [`TruncationError`] is returned. + pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { + if self.context.size <= size { + let message = self.message.as_ptr(); + let message = unsafe { Message::ptr_slice_to(message, size) }; + self.message = unsafe { NonNull::new_unchecked(message) }; + Ok(()) + } else { + Err(TruncationError) + } + } + /// Mark bytes in the buffer as initialized. /// /// The given number of bytes from the beginning of @@ -250,7 +479,7 @@ impl Builder<'_> { /// Panics if the uninitialized buffer is smaller than the given number of /// initialized bytes. pub fn mark_appended(&mut self, amount: usize) { - assert!(self.context.max_size - self.context.size >= amount); + assert!(self.max_size() - self.context.size >= amount); self.context.size += amount; } @@ -265,30 +494,6 @@ impl Builder<'_> { } } - /// Limit the total message size. - /// - /// The message will not be allowed to exceed the given size, in bytes. - /// Only the message header and contents are counted; the enclosing UDP - /// or TCP packet size is not considered. If the message already exceeds - /// this size, a [`TruncationError`] is returned. - /// - /// This size will apply to all builders for this message (including those - /// that delegated to `self`). It will not be automatically revoked if - /// message building fails. - /// - /// # Panics - /// - /// Panics if the given size is less than 12 bytes. - pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { - assert!(size >= 12); - if self.context.size <= size - 12 { - self.context.max_size = size - 12; - Ok(()) - } else { - Err(TruncationError) - } - } - /// Append data of a known size using a closure. /// /// All the requested bytes must be initialized. If not enough free space @@ -336,40 +541,27 @@ impl Drop for Builder<'_> { } } +//--- Send, Sync + +// SAFETY: The parts of the referenced message that can be accessed mutably +// are not accessible by any reference other than `self`. +unsafe impl Send for Builder<'_> {} + +// SAFETY: Only parts of the referenced message that are borrowed immutably +// can be accessed through an immutable reference to `self`. +unsafe impl Sync for Builder<'_> {} + //----------- BuilderContext ------------------------------------------------- /// Context for building a DNS message. -#[derive(Clone, Debug)] +/// +/// This type holds auxiliary information necessary for building DNS messages, +/// e.g. name compression state. To construct it, call [`default()`]. +/// +/// [`default()`]: Self::default() +#[derive(Clone, Debug, Default)] pub struct BuilderContext { // TODO: Name compression. /// The current size of the message contents. size: usize, - - /// The maximum size of the message contents. - max_size: usize, -} - -//--- Inspection - -impl BuilderContext { - /// The size of the message contents. - pub fn size(&self) -> usize { - self.size - } - - /// The maximum size of the message contents. - pub fn max_size(&self) -> usize { - self.max_size - } -} - -//--- Default - -impl Default for BuilderContext { - fn default() -> Self { - Self { - size: 0, - max_size: 65535 - core::mem::size_of::
(), - } - } } diff --git a/src/new_base/message.rs b/src/new_base/message.rs index d900b4a1c..27e4cde88 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -60,6 +60,24 @@ impl Message { Self::parse_bytes_by_mut(bytes) .expect("A 12-or-more byte string is a valid 'Message'") } + + /// Truncate the contents of this message to the given size, by pointer. + /// + /// The returned value will have a `contents` field of the given size. + /// + /// # Safety + /// + /// This method uses `pointer::offset()`: `self` must be "derived from a + /// pointer to some allocated object". There must be at least 12 bytes + /// between `self` and the end of that allocated object. A reference to + /// `Message` will always result in a pointer satisfying this. + pub unsafe fn ptr_slice_to(this: *mut Message, size: usize) -> *mut Self { + let bytes = unsafe { core::ptr::addr_of_mut!((*this).contents) }; + let len = unsafe { &*(bytes as *mut [()]) }.len(); + debug_assert!(size <= len); + core::ptr::slice_from_raw_parts_mut(this.cast::(), size) + as *mut Self + } } //----------- Header --------------------------------------------------------- From ef702e300be3bb7268140107e3269739fd353cfd Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 20 Jan 2025 17:41:26 +0100 Subject: [PATCH 090/167] [new_base/build] Add module documentation --- src/new_base/build/mod.rs | 88 +++++++++++++++++++++++++++++++++++++++ src/new_base/question.rs | 56 +++++++++++++++++++++++++ src/new_base/wire/ints.rs | 5 +++ 3 files changed, 149 insertions(+) diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 38ab501fe..e723ef521 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -1,4 +1,92 @@ //! Building DNS messages in the wire format. +//! +//! The [`wire`](super::wire) module provides basic serialization capability, +//! but it is not specialized to DNS messages. This module provides that +//! specialization within an ergonomic interface. +//! +//! # The High-Level Interface +//! +//! The core of the high-level interface is [`MessageBuilder`]. It provides +//! the most intuitive methods for appending whole questions and records. +//! +//! ``` +//! use domain::new_base::{Header, HeaderFlags, Question, QType, QClass}; +//! use domain::new_base::build::{BuilderContext, MessageBuilder, BuildIntoMessage}; +//! use domain::new_base::name::RevName; +//! use domain::new_base::wire::U16; +//! +//! // Initialize a DNS message builder. +//! let mut buffer = [0u8; 512]; +//! let mut context = BuilderContext::default(); +//! let mut builder = MessageBuilder::new(&mut buffer, &mut context); +//! +//! // Initialize the message header. +//! let header = builder.header_mut(); +//! *builder.header_mut() = Header { +//! // Select a randomized ID here. +//! id: U16::new(1234), +//! // A recursive query for authoritative data. +//! flags: HeaderFlags::default() +//! .query(0) +//! .set_authoritative(true) +//! .request_recursion(true), +//! counts: Default::default(), +//! }; +//! +//! // Add a question for an A record. +//! // TODO: Use a more ergonomic way to make a name. +//! let name = b"\x00\x03org\x07example\x03www"; +//! let name = unsafe { RevName::from_bytes_unchecked(name) }; +//! let question = Question { +//! qname: name, +//! qtype: QType::A, +//! qclass: QClass::IN, +//! }; +//! builder.append_question(&question).unwrap(); +//! +//! // Use the built message. +//! let message = builder.message(); +//! # let _ = message; +//! ``` +//! +//! # The Low-Level Interface +//! +//! [`Builder`] is a powerful low-level interface that can be used to build +//! DNS messages. It implements atomic building and name compression, and is +//! the foundation of [`MessageBuilder`]. +//! +//! The [`Builder`] interface does not know about questions and records; it is +//! only capable of appending simple bytes and compressing domain names. Its +//! access to the message buffer is limited; it can only append, modify, or +//! truncate the message up to a certain point (all data before that point is +//! immutable). Special attention is given to the message header, as it can +//! be modified at any point in the message building process. +//! +//! ``` +//! use domain::new_base::build::{BuilderContext, Builder, BuildIntoMessage}; +//! use domain::new_rdata::A; +//! +//! // Construct a builder for a particular buffer. +//! let mut buffer = [0u8; 20]; +//! let mut context = BuilderContext::default(); +//! let mut builder = Builder::new(&mut buffer, &mut context); +//! +//! // Try appending some raw bytes to the builder. +//! builder.append_bytes(b"hi! ").unwrap(); +//! assert_eq!(builder.appended(), b"hi! "); +//! +//! // Try appending some structured content to the builder. +//! A::from(std::net::Ipv4Addr::new(127, 0, 0, 1)) +//! .build_into_message(builder.delegate()) +//! .unwrap(); +//! assert_eq!(builder.appended(), b"hi! \x7F\x00\x00\x01"); +//! +//! // Finish using the builder. +//! builder.commit(); +//! +//! // Note: the first 12 bytes hold the message header. +//! assert_eq!(&buffer[12..20], b"hi! \x7F\x00\x00\x01"); +//! ``` mod builder; pub use builder::{Builder, BuilderContext}; diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 720d46e14..e4602a4b6 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -115,6 +115,46 @@ pub struct QType { pub code: U16, } +//--- Associated Constants + +impl QType { + const fn new(value: u16) -> Self { + Self { + code: U16::new(value), + } + } + + /// The type of an [`A`](crate::new_rdata::A) record. + pub const A: Self = Self::new(1); + + /// The type of an [`Ns`](crate::new_rdata::Ns) record. + pub const NS: Self = Self::new(2); + + /// The type of a [`CName`](crate::new_rdata::CName) record. + pub const CNAME: Self = Self::new(5); + + /// The type of an [`Soa`](crate::new_rdata::Soa) record. + pub const SOA: Self = Self::new(6); + + /// The type of a [`Wks`](crate::new_rdata::Wks) record. + pub const WKS: Self = Self::new(11); + + /// The type of a [`Ptr`](crate::new_rdata::Ptr) record. + pub const PTR: Self = Self::new(12); + + /// The type of a [`HInfo`](crate::new_rdata::HInfo) record. + pub const HINFO: Self = Self::new(13); + + /// The type of a [`Mx`](crate::new_rdata::Mx) record. + pub const MX: Self = Self::new(15); + + /// The type of a [`Txt`](crate::new_rdata::Txt) record. + pub const TXT: Self = Self::new(16); + + /// The type of an [`Aaaa`](crate::new_rdata::Aaaa) record. + pub const AAAA: Self = Self::new(28); +} + //----------- QClass --------------------------------------------------------- /// The class of a question. @@ -139,3 +179,19 @@ pub struct QClass { /// The class code. pub code: U16, } + +//--- Associated Constants + +impl QClass { + const fn new(value: u16) -> Self { + Self { + code: U16::new(value), + } + } + + /// The Internet class. + pub const IN: Self = Self::new(1); + + /// The CHAOS class. + pub const CH: Self = Self::new(3); +} diff --git a/src/new_base/wire/ints.rs b/src/new_base/wire/ints.rs index 3d11f45e4..15834a55f 100644 --- a/src/new_base/wire/ints.rs +++ b/src/new_base/wire/ints.rs @@ -51,6 +51,11 @@ macro_rules! define_int { pub const fn get(self) -> $base { <$base>::from_be_bytes(self.0) } + + /// Overwrite this value with an integer. + pub const fn set(&mut self, value: $base) { + *self = Self::new(value) + } } impl From<$base> for $name { From 0e7346cead2e6317a6ff4dd63876cd13f0774266 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 20 Jan 2025 17:44:31 +0100 Subject: [PATCH 091/167] [new_base/parse] Add a bit of module documentation --- src/new_base/parse/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index e6d47f4f0..6744c15a6 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,4 +1,9 @@ //! Parsing DNS messages from the wire format. +//! +//! This module provides [`ParseFromMessage`] and [`SplitFromMessage`], which +//! are specializations of [`ParseBytes`] and [`SplitBytes`] to DNS messages. +//! When parsing data within a DNS message, these traits allow access to all +//! preceding bytes in the message so that compressed names can be resolved. pub use super::wire::ParseError; From 967d1d5a84eb50becf4a64f3711e8a58ec675070 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 20 Jan 2025 17:45:10 +0100 Subject: [PATCH 092/167] [new_base/parse] Add missing doc links --- src/new_base/parse/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 6744c15a6..cbe949681 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -4,6 +4,9 @@ //! are specializations of [`ParseBytes`] and [`SplitBytes`] to DNS messages. //! When parsing data within a DNS message, these traits allow access to all //! preceding bytes in the message so that compressed names can be resolved. +//! +//! [`ParseBytes`]: super::wire::ParseBytes +//! [`SplitBytes`]: super::wire::SplitBytes pub use super::wire::ParseError; From 4fa09b25db12330c5e02da0a842696cec36b04b6 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 20 Jan 2025 17:51:39 +0100 Subject: [PATCH 093/167] [new_base/wire/ints] Make 'set()' non-const --- src/new_base/wire/ints.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/new_base/wire/ints.rs b/src/new_base/wire/ints.rs index 15834a55f..acd2990e2 100644 --- a/src/new_base/wire/ints.rs +++ b/src/new_base/wire/ints.rs @@ -53,7 +53,8 @@ macro_rules! define_int { } /// Overwrite this value with an integer. - pub const fn set(&mut self, value: $base) { + // TODO: Make 'const' at MSRV 1.83.0. + pub fn set(&mut self, value: $base) { *self = Self::new(value) } } From 9e44e48c61740ae9dc4c92689a2782fe331de06d Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 22 Jan 2025 14:40:26 +0100 Subject: [PATCH 094/167] [new_base/wire/parse] Fix documentation typo See: --- src/new_base/wire/parse.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/new_base/wire/parse.rs b/src/new_base/wire/parse.rs index 3ee5d44a1..2a8da6642 100644 --- a/src/new_base/wire/parse.rs +++ b/src/new_base/wire/parse.rs @@ -89,8 +89,7 @@ impl<'a> SplitBytes<'a> for u8 { /// Deriving [`SplitBytes`] automatically. /// /// [`SplitBytes`] can be derived on `struct`s (not `enum`s or `union`s). All -/// fields except the last must implement [`SplitBytes`], while the last field -/// only needs to implement [`SplitBytes`]. +/// fields must implement [`SplitBytes`]. /// /// Here's a simple example: /// From 9af8ea199cb2435919961338525a06d8e62e44ed Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 22 Jan 2025 14:45:27 +0100 Subject: [PATCH 095/167] [new_edns/cookie] Rename 'CookieRequest' to 'ClientCookie' Also fixes typo in field name 'reversed' ('reserved') of 'Cookie'. See: See: --- src/new_edns/cookie.rs | 30 +++++++++++++++--------------- src/new_edns/mod.rs | 20 ++++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs index 36d96a4f4..77c810be5 100644 --- a/src/new_edns/cookie.rs +++ b/src/new_edns/cookie.rs @@ -20,7 +20,7 @@ use crate::new_base::Serial; #[cfg(all(feature = "std", feature = "siphasher"))] use crate::new_base::wire::{AsBytes, TruncationError}; -//----------- CookieRequest -------------------------------------------------- +//----------- ClientCookie --------------------------------------------------- /// A request for a DNS cookie. #[derive( @@ -37,15 +37,15 @@ use crate::new_base::wire::{AsBytes, TruncationError}; SplitBytesByRef, )] #[repr(transparent)] -pub struct CookieRequest { +pub struct ClientCookie { /// The octets of the request. pub octets: [u8; 8], } //--- Construction -impl CookieRequest { - /// Construct a random [`CookieRequest`]. +impl ClientCookie { + /// Construct a random [`ClientCookie`]. #[cfg(feature = "rand")] pub fn random() -> Self { rand::random::<[u8; 8]>().into() @@ -54,7 +54,7 @@ impl CookieRequest { //--- Interaction -impl CookieRequest { +impl ClientCookie { /// Build a [`Cookie`] in response to this request. /// /// A 24-byte version-1 interoperable cookie will be generated and written @@ -101,27 +101,27 @@ impl CookieRequest { //--- Conversion to and from octets -impl From<[u8; 8]> for CookieRequest { +impl From<[u8; 8]> for ClientCookie { fn from(value: [u8; 8]) -> Self { Self { octets: value } } } -impl From for [u8; 8] { - fn from(value: CookieRequest) -> Self { +impl From for [u8; 8] { + fn from(value: ClientCookie) -> Self { value.octets } } //--- Formatting -impl fmt::Debug for CookieRequest { +impl fmt::Debug for ClientCookie { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "CookieRequest({})", self) + write!(f, "ClientCookie({})", self) } } -impl fmt::Display for CookieRequest { +impl fmt::Display for ClientCookie { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:016X}", u64::from_be_bytes(self.octets)) } @@ -135,14 +135,14 @@ impl fmt::Display for CookieRequest { )] #[repr(C)] pub struct Cookie { - /// The request for this cookie. - request: CookieRequest, + /// The client's request for this cookie. + request: ClientCookie, /// The version number of this cookie. version: u8, /// Reserved bytes in the cookie format. - reversed: [u8; 3], + reserved: [u8; 3], /// When this cookie was made. timestamp: Serial, @@ -155,7 +155,7 @@ pub struct Cookie { impl Cookie { /// The underlying cookie request. - pub fn request(&self) -> &CookieRequest { + pub fn request(&self) -> &ClientCookie { &self.request } diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 152cd5dae..8bc3c55b6 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -21,7 +21,7 @@ use crate::{ //----------- EDNS option modules -------------------------------------------- mod cookie; -pub use cookie::{Cookie, CookieRequest}; +pub use cookie::{ClientCookie, Cookie}; mod ext_err; pub use ext_err::{ExtError, ExtErrorCode}; @@ -212,10 +212,10 @@ impl fmt::Debug for EdnsFlags { #[derive(Debug)] #[non_exhaustive] pub enum EdnsOption<'b> { - /// A request for a DNS cookie. - CookieRequest(&'b CookieRequest), + /// A client's request for a DNS cookie. + ClientCookie(&'b ClientCookie), - /// A DNS cookie. + /// A server-provided DNS cookie. Cookie(&'b Cookie), /// An extended DNS error. @@ -231,7 +231,7 @@ impl EdnsOption<'_> { /// The code for this option. pub fn code(&self) -> OptionCode { match self { - Self::CookieRequest(_) => OptionCode::COOKIE, + Self::ClientCookie(_) => OptionCode::COOKIE, Self::Cookie(_) => OptionCode::COOKIE, Self::ExtError(_) => OptionCode::EXT_ERROR, Self::Unknown(code, _) => *code, @@ -248,8 +248,8 @@ impl<'b> ParseBytes<'b> for EdnsOption<'b> { match code { OptionCode::COOKIE => match data.len() { - 8 => CookieRequest::parse_bytes_by_ref(data) - .map(Self::CookieRequest), + 8 => ClientCookie::parse_bytes_by_ref(data) + .map(Self::ClientCookie), 16..=40 => Cookie::parse_bytes_by_ref(data).map(Self::Cookie), _ => Err(ParseError), }, @@ -273,8 +273,8 @@ impl<'b> SplitBytes<'b> for EdnsOption<'b> { let this = match code { OptionCode::COOKIE => match data.len() { - 8 => <&CookieRequest>::parse_bytes(data) - .map(Self::CookieRequest)?, + 8 => <&ClientCookie>::parse_bytes(data) + .map(Self::ClientCookie)?, 16..=40 => <&Cookie>::parse_bytes(data).map(Self::Cookie)?, _ => return Err(ParseError), }, @@ -301,7 +301,7 @@ impl BuildBytes for EdnsOption<'_> { bytes = self.code().build_bytes(bytes)?; let data = match self { - Self::CookieRequest(this) => this.as_bytes(), + Self::ClientCookie(this) => this.as_bytes(), Self::Cookie(this) => this.as_bytes(), Self::ExtError(this) => this.as_bytes(), Self::Unknown(_, this) => this.as_bytes(), From c5f1552c5ee4745b68b8f5fd68bfbcce3bf64b4f Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 22 Jan 2025 15:22:01 +0100 Subject: [PATCH 096/167] [new_base/parse] Refactor '{Parse,Split}FromMessage' - The traits have been renamed to '{Parse,Split}MessageBytes' for consistency with '{Parse,Split}Bytes'. - The traits now consume 'message.contents' instead of the whole 'message', allowing them to be used in more contexts (e.g. the upcoming overhauled builder types). --- src/new_base/charstr.rs | 26 ++++------ src/new_base/name/reversed.rs | 17 +++--- src/new_base/name/unparsed.rs | 19 ++++--- src/new_base/parse/mod.rs | 63 +++++++++++------------ src/new_base/question.rs | 31 ++++++----- src/new_base/record.rs | 50 +++++++++--------- src/new_base/wire/size_prefixed.rs | 30 +++++------ src/new_edns/mod.rs | 26 ++++------ src/new_rdata/basic.rs | 83 ++++++++++++++---------------- src/new_rdata/mod.rs | 38 ++++++++------ 10 files changed, 179 insertions(+), 204 deletions(-) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 8df3c3d7c..3eab4d7cd 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -4,9 +4,8 @@ use core::fmt; use super::{ build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, - Message, }; //----------- CharStr -------------------------------------------------------- @@ -20,27 +19,22 @@ pub struct CharStr { //--- Parsing from DNS messages -impl<'a> SplitFromMessage<'a> for &'a CharStr { - fn split_from_message( - message: &'a Message, +impl<'a> SplitMessageBytes<'a> for &'a CharStr { + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = message.contents.get(start..).ok_or(ParseError)?; - let (this, rest) = Self::split_bytes(bytes)?; - Ok((this, bytes.len() - rest.len())) + Self::split_bytes(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) } } -impl<'a> ParseFromMessage<'a> for &'a CharStr { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for &'a CharStr { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - message - .contents - .get(start..) - .ok_or(ParseError) - .and_then(Self::parse_bytes) + Self::parse_bytes(&contents[start..]) } } diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index f33451f3b..55a83e82e 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -10,9 +10,8 @@ use core::{ use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, - Message, }; use super::LabelIter; @@ -239,9 +238,9 @@ impl RevNameBuf { //--- Parsing from DNS messages -impl<'a> SplitFromMessage<'a> for RevNameBuf { - fn split_from_message( - message: &'a Message, +impl<'a> SplitMessageBytes<'a> for RevNameBuf { + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { // NOTE: The input may be controlled by an attacker. Compression @@ -251,7 +250,6 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { // disallow a name to point to data _after_ it. Standard name // compressors will never generate such pointers. - let contents = &message.contents; let mut buffer = Self::empty(); // Perform the first iteration early, to catch the end of the name. @@ -282,16 +280,15 @@ impl<'a> SplitFromMessage<'a> for RevNameBuf { } } -impl<'a> ParseFromMessage<'a> for RevNameBuf { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for RevNameBuf { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { // See 'split_from_message()' for details. The only differences are // in the range of the first iteration, and the check that the first // iteration exactly covers the input range. - let contents = &message.contents; let mut buffer = Self::empty(); // Perform the first iteration early, to catch the end of the name. diff --git a/src/new_base/name/unparsed.rs b/src/new_base/name/unparsed.rs index 828c92229..4a76bee6c 100644 --- a/src/new_base/name/unparsed.rs +++ b/src/new_base/name/unparsed.rs @@ -3,9 +3,8 @@ use domain_macros::*; use crate::new_base::{ - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::ParseError, - Message, }; //----------- UnparsedName --------------------------------------------------- @@ -77,12 +76,12 @@ impl UnparsedName { //--- Parsing from DNS messages -impl<'a> SplitFromMessage<'a> for &'a UnparsedName { - fn split_from_message( - message: &'a Message, +impl<'a> SplitMessageBytes<'a> for &'a UnparsedName { + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = message.contents.get(start..).ok_or(ParseError)?; + let bytes = &contents[start..]; let mut offset = 0; let offset = loop { match bytes[offset..] { @@ -120,12 +119,12 @@ impl<'a> SplitFromMessage<'a> for &'a UnparsedName { } } -impl<'a> ParseFromMessage<'a> for &'a UnparsedName { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for &'a UnparsedName { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let bytes = message.contents.get(start..).ok_or(ParseError)?; + let bytes = &contents[start..]; let mut offset = 0; loop { match bytes[offset..] { diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index cbe949681..03ac16995 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -1,69 +1,68 @@ //! Parsing DNS messages from the wire format. //! -//! This module provides [`ParseFromMessage`] and [`SplitFromMessage`], which -//! are specializations of [`ParseBytes`] and [`SplitBytes`] to DNS messages. -//! When parsing data within a DNS message, these traits allow access to all -//! preceding bytes in the message so that compressed names can be resolved. +//! This module provides [`ParseMessageBytes`] and [`SplitMessageBytes`], +//! which are specializations of [`ParseBytes`] and [`SplitBytes`] to DNS +//! messages. When parsing data within a DNS message, these traits allow +//! access to all preceding bytes in the message so that compressed names can +//! be resolved. //! //! [`ParseBytes`]: super::wire::ParseBytes //! [`SplitBytes`]: super::wire::SplitBytes pub use super::wire::ParseError; -use super::{ - wire::{ParseBytesByRef, SplitBytesByRef}, - Message, -}; +use super::wire::{ParseBytesByRef, SplitBytesByRef}; -//----------- Message-aware parsing traits ----------------------------------- +//----------- Message parsing traits ----------------------------------------- /// A type that can be parsed from a DNS message. -pub trait SplitFromMessage<'a>: Sized + ParseFromMessage<'a> { +pub trait SplitMessageBytes<'a>: Sized + ParseMessageBytes<'a> { /// Parse a value from the start of a byte string within a DNS message. /// - /// The byte string to parse is `message.contents[start..]`. The previous - /// data in the message can be used for resolving compressed names. + /// The contents of the DNS message is provided as `contents`. + /// `contents[start..]` is the beginning of the input to be parsed. The + /// earlier bytes are provided for resolving compressed domain names. /// /// If parsing is successful, the parsed value and the offset for the rest /// of the input are returned. If `len` bytes were parsed to form `self`, /// `start + len` should be the returned offset. - fn split_from_message( - message: &'a Message, + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError>; } -/// A type that can be parsed from a string in a DNS message. -pub trait ParseFromMessage<'a>: Sized { - /// Parse a value from a byte string within a DNS message. +/// A type that can be parsed from bytes in a DNS message. +pub trait ParseMessageBytes<'a>: Sized { + /// Parse a value from bytes in a DNS message. /// - /// The byte string to parse is `message.contents[start..]`. The previous - /// data in the message can be used for resolving compressed names. + /// The contents of the DNS message (up to and including the actual bytes + /// to be parsed) is provided as `contents`. `contents[start..]` is the + /// input to be parsed. The earlier bytes are provided for resolving + /// compressed domain names. /// /// If parsing is successful, the parsed value is returned. - fn parse_from_message( - message: &'a Message, + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result; } -impl<'a, T: ?Sized + SplitBytesByRef> SplitFromMessage<'a> for &'a T { - fn split_from_message( - message: &'a Message, +impl<'a, T: ?Sized + SplitBytesByRef> SplitMessageBytes<'a> for &'a T { + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = message.contents.get(start..).ok_or(ParseError)?; - let (this, rest) = T::split_bytes_by_ref(bytes)?; - Ok((this, bytes.len() - rest.len())) + T::split_bytes_by_ref(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) } } -impl<'a, T: ?Sized + ParseBytesByRef> ParseFromMessage<'a> for &'a T { - fn parse_from_message( - message: &'a Message, +impl<'a, T: ?Sized + ParseBytesByRef> ParseMessageBytes<'a> for &'a T { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let bytes = message.contents.get(start..).ok_or(ParseError)?; - T::parse_bytes_by_ref(bytes) + T::parse_bytes_by_ref(&contents[start..]) } } diff --git a/src/new_base/question.rs b/src/new_base/question.rs index e4602a4b6..04dfa2d9e 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -5,9 +5,8 @@ use domain_macros::*; use super::{ build::{self, BuildIntoMessage, BuildResult}, name::RevNameBuf, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{AsBytes, ParseError, U16}, - Message, }; //----------- Question ------------------------------------------------------- @@ -43,32 +42,32 @@ impl Question { //--- Parsing from DNS messages -impl<'a, N> SplitFromMessage<'a> for Question +impl<'a, N> SplitMessageBytes<'a> for Question where - N: SplitFromMessage<'a>, + N: SplitMessageBytes<'a>, { - fn split_from_message( - message: &'a Message, + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let (qname, rest) = N::split_from_message(message, start)?; - let (&qtype, rest) = <&QType>::split_from_message(message, rest)?; - let (&qclass, rest) = <&QClass>::split_from_message(message, rest)?; + let (qname, rest) = N::split_message_bytes(contents, start)?; + let (&qtype, rest) = <&QType>::split_message_bytes(contents, rest)?; + let (&qclass, rest) = <&QClass>::split_message_bytes(contents, rest)?; Ok((Self::new(qname, qtype, qclass), rest)) } } -impl<'a, N> ParseFromMessage<'a> for Question +impl<'a, N> ParseMessageBytes<'a> for Question where - N: SplitFromMessage<'a>, + N: SplitMessageBytes<'a>, { - fn parse_from_message( - message: &'a Message, + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let (qname, rest) = N::split_from_message(message, start)?; - let (&qtype, rest) = <&QType>::split_from_message(message, rest)?; - let &qclass = <&QClass>::parse_from_message(message, rest)?; + let (qname, rest) = N::split_message_bytes(contents, start)?; + let (&qtype, rest) = <&QType>::split_message_bytes(contents, rest)?; + let &qclass = <&QClass>::parse_message_bytes(contents, rest)?; Ok(Self::new(qname, qtype, qclass)) } } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 742a66977..0c0e6fa3c 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -5,12 +5,11 @@ use core::{borrow::Borrow, ops::Deref}; use super::{ build::{self, BuildIntoMessage, BuildResult}, name::RevNameBuf, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, SizePrefixed, SplitBytes, SplitBytesByRef, TruncationError, U16, U32, }, - Message, }; //----------- Record --------------------------------------------------------- @@ -60,44 +59,44 @@ impl Record { //--- Parsing from DNS messages -impl<'a, N, D> SplitFromMessage<'a> for Record +impl<'a, N, D> SplitMessageBytes<'a> for Record where - N: SplitFromMessage<'a>, + N: SplitMessageBytes<'a>, D: ParseRecordData<'a>, { - fn split_from_message( - message: &'a Message, + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let (rname, rest) = N::split_from_message(message, start)?; - let (&rtype, rest) = <&RType>::split_from_message(message, rest)?; - let (&rclass, rest) = <&RClass>::split_from_message(message, rest)?; - let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; + let (rname, rest) = N::split_message_bytes(contents, start)?; + let (&rtype, rest) = <&RType>::split_message_bytes(contents, rest)?; + let (&rclass, rest) = <&RClass>::split_message_bytes(contents, rest)?; + let (&ttl, rest) = <&TTL>::split_message_bytes(contents, rest)?; let rdata_start = rest; let (_, rest) = - <&SizePrefixed<[u8]>>::split_from_message(message, rest)?; - let message = message.slice_to(rest); - let rdata = D::parse_record_data(message, rdata_start, rtype)?; + <&SizePrefixed<[u8]>>::split_message_bytes(contents, rest)?; + let rdata = + D::parse_record_data(&contents[..rest], rdata_start, rtype)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } } -impl<'a, N, D> ParseFromMessage<'a> for Record +impl<'a, N, D> ParseMessageBytes<'a> for Record where - N: SplitFromMessage<'a>, + N: SplitMessageBytes<'a>, D: ParseRecordData<'a>, { - fn parse_from_message( - message: &'a Message, + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let (rname, rest) = N::split_from_message(message, start)?; - let (&rtype, rest) = <&RType>::split_from_message(message, rest)?; - let (&rclass, rest) = <&RClass>::split_from_message(message, rest)?; - let (&ttl, rest) = <&TTL>::split_from_message(message, rest)?; - let _ = <&SizePrefixed<[u8]>>::parse_from_message(message, rest)?; - let rdata = D::parse_record_data(message, rest, rtype)?; + let (rname, rest) = N::split_message_bytes(contents, start)?; + let (&rtype, rest) = <&RType>::split_message_bytes(contents, rest)?; + let (&rclass, rest) = <&RClass>::split_message_bytes(contents, rest)?; + let (&ttl, rest) = <&TTL>::split_message_bytes(contents, rest)?; + let _ = <&SizePrefixed<[u8]>>::parse_message_bytes(contents, rest)?; + let rdata = D::parse_record_data(contents, rest, rtype)?; Ok(Self::new(rname, rtype, rclass, ttl, rdata)) } @@ -305,12 +304,11 @@ pub struct TTL { pub trait ParseRecordData<'a>: Sized { /// Parse DNS record data of the given type from a DNS message. fn parse_record_data( - message: &'a Message, + contents: &'a [u8], start: usize, rtype: RType, ) -> Result { - let bytes = message.contents.get(start..).ok_or(ParseError)?; - Self::parse_record_data_bytes(bytes, rtype) + Self::parse_record_data_bytes(&contents[start..], rtype) } /// Parse DNS record data of the given type from a byte string. diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index 751a57395..5ac9effa9 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -7,8 +7,7 @@ use core::{ use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseFromMessage, SplitFromMessage}, - Message, + parse::{ParseMessageBytes, SplitMessageBytes}, }; use super::{ @@ -110,31 +109,28 @@ impl AsMut for SizePrefixed { //--- Parsing from DNS messages -impl<'b, T: ParseFromMessage<'b>> ParseFromMessage<'b> for SizePrefixed { - fn parse_from_message( - message: &'b Message, +impl<'b, T: ParseMessageBytes<'b>> ParseMessageBytes<'b> for SizePrefixed { + fn parse_message_bytes( + contents: &'b [u8], start: usize, ) -> Result { - let (&size, rest) = <&U16>::split_from_message(message, start)?; - if rest + size.get() as usize != message.contents.len() { + let (&size, rest) = <&U16>::split_message_bytes(contents, start)?; + if rest + size.get() as usize != contents.len() { return Err(ParseError); } - T::parse_from_message(message, rest).map(Self::new) + T::parse_message_bytes(contents, rest).map(Self::new) } } -impl<'b, T: ParseFromMessage<'b>> SplitFromMessage<'b> for SizePrefixed { - fn split_from_message( - message: &'b Message, +impl<'b, T: ParseMessageBytes<'b>> SplitMessageBytes<'b> for SizePrefixed { + fn split_message_bytes( + contents: &'b [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let (&size, rest) = <&U16>::split_from_message(message, start)?; + let (&size, rest) = <&U16>::split_message_bytes(contents, start)?; let (start, rest) = (rest, rest + size.get() as usize); - if rest > message.contents.len() { - return Err(ParseError); - } - let message = message.slice_to(rest); - let data = T::parse_from_message(message, start)?; + let contents = contents.get(..rest).ok_or(ParseError)?; + let data = T::parse_message_bytes(contents, start)?; Ok((Self::new(data), rest)) } } diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 8bc3c55b6..96233221d 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -8,12 +8,11 @@ use domain_macros::*; use crate::{ new_base::{ - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, SizePrefixed, SplitBytes, TruncationError, U16, }, - Message, }, new_rdata::Opt, }; @@ -49,27 +48,22 @@ pub struct EdnsRecord<'a> { //--- Parsing from DNS messages -impl<'a> SplitFromMessage<'a> for EdnsRecord<'a> { - fn split_from_message( - message: &'a Message, +impl<'a> SplitMessageBytes<'a> for EdnsRecord<'a> { + fn split_message_bytes( + contents: &'a [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let bytes = message.contents.get(start..).ok_or(ParseError)?; - let (this, rest) = Self::split_bytes(bytes)?; - Ok((this, message.contents.len() - rest.len())) + Self::split_bytes(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) } } -impl<'a> ParseFromMessage<'a> for EdnsRecord<'a> { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for EdnsRecord<'a> { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - message - .contents - .get(start..) - .ok_or(ParseError) - .and_then(Self::parse_bytes) + Self::parse_bytes(&contents[start..]) } } diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 456da881c..43d089100 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -14,9 +14,9 @@ use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{AsBytes, ParseBytes, ParseError, SplitBytes, U16, U32}, - CharStr, Message, Serial, + CharStr, Serial, }; //----------- A -------------------------------------------------------------- @@ -114,12 +114,12 @@ pub struct Ns { //--- Parsing from DNS messages -impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ns { - fn parse_from_message( - message: &'a Message, +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ns { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - N::parse_from_message(message, start).map(|name| Self { name }) + N::parse_message_bytes(contents, start).map(|name| Self { name }) } } @@ -155,12 +155,12 @@ pub struct CName { //--- Parsing from DNS messages -impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for CName { - fn parse_from_message( - message: &'a Message, +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for CName { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - N::parse_from_message(message, start).map(|name| Self { name }) + N::parse_message_bytes(contents, start).map(|name| Self { name }) } } @@ -211,18 +211,18 @@ pub struct Soa { //--- Parsing from DNS messages -impl<'a, N: SplitFromMessage<'a>> ParseFromMessage<'a> for Soa { - fn parse_from_message( - message: &'a Message, +impl<'a, N: SplitMessageBytes<'a>> ParseMessageBytes<'a> for Soa { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let (mname, rest) = N::split_from_message(message, start)?; - let (rname, rest) = N::split_from_message(message, rest)?; - let (&serial, rest) = <&Serial>::split_from_message(message, rest)?; - let (&refresh, rest) = <&U32>::split_from_message(message, rest)?; - let (&retry, rest) = <&U32>::split_from_message(message, rest)?; - let (&expire, rest) = <&U32>::split_from_message(message, rest)?; - let &minimum = <&U32>::parse_from_message(message, rest)?; + let (mname, rest) = N::split_message_bytes(contents, start)?; + let (rname, rest) = N::split_message_bytes(contents, rest)?; + let (&serial, rest) = <&Serial>::split_message_bytes(contents, rest)?; + let (&refresh, rest) = <&U32>::split_message_bytes(contents, rest)?; + let (&retry, rest) = <&U32>::split_message_bytes(contents, rest)?; + let (&expire, rest) = <&U32>::split_message_bytes(contents, rest)?; + let &minimum = <&U32>::parse_message_bytes(contents, rest)?; Ok(Self { mname, @@ -330,12 +330,12 @@ pub struct Ptr { //--- Parsing from DNS messages -impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Ptr { - fn parse_from_message( - message: &'a Message, +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ptr { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - N::parse_from_message(message, start).map(|name| Self { name }) + N::parse_message_bytes(contents, start).map(|name| Self { name }) } } @@ -361,16 +361,12 @@ pub struct HInfo<'a> { //--- Parsing from DNS messages -impl<'a> ParseFromMessage<'a> for HInfo<'a> { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for HInfo<'a> { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - message - .contents - .get(start..) - .ok_or(ParseError) - .and_then(Self::parse_bytes) + Self::parse_bytes(&contents[start..]) } } @@ -414,13 +410,14 @@ pub struct Mx { //--- Parsing from DNS messages -impl<'a, N: ParseFromMessage<'a>> ParseFromMessage<'a> for Mx { - fn parse_from_message( - message: &'a Message, +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Mx { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - let (&preference, rest) = <&U16>::split_from_message(message, start)?; - let exchange = N::parse_from_message(message, rest)?; + let (&preference, rest) = + <&U16>::split_message_bytes(contents, start)?; + let exchange = N::parse_message_bytes(contents, rest)?; Ok(Self { preference, exchange, @@ -471,16 +468,12 @@ impl Txt { //--- Parsing from DNS messages -impl<'a> ParseFromMessage<'a> for &'a Txt { - fn parse_from_message( - message: &'a Message, +impl<'a> ParseMessageBytes<'a> for &'a Txt { + fn parse_message_bytes( + contents: &'a [u8], start: usize, ) -> Result { - message - .contents - .get(start..) - .ok_or(ParseError) - .and_then(Self::parse_bytes) + Self::parse_bytes(&contents[start..]) } } diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index e4b94a538..70f041240 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -4,9 +4,9 @@ use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseFromMessage, SplitFromMessage}, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, - Message, ParseRecordData, RType, + ParseRecordData, RType, }; //----------- Concrete record data types ------------------------------------- @@ -67,42 +67,48 @@ pub enum RecordData<'a, N> { impl<'a, N> ParseRecordData<'a> for RecordData<'a, N> where - N: SplitBytes<'a> + SplitFromMessage<'a>, + N: SplitBytes<'a> + SplitMessageBytes<'a>, { fn parse_record_data( - message: &'a Message, + contents: &'a [u8], start: usize, rtype: RType, ) -> Result { match rtype { - RType::A => <&A>::parse_from_message(message, start).map(Self::A), - RType::NS => Ns::parse_from_message(message, start).map(Self::Ns), + RType::A => { + <&A>::parse_message_bytes(contents, start).map(Self::A) + } + RType::NS => { + Ns::parse_message_bytes(contents, start).map(Self::Ns) + } RType::CNAME => { - CName::parse_from_message(message, start).map(Self::CName) + CName::parse_message_bytes(contents, start).map(Self::CName) } RType::SOA => { - Soa::parse_from_message(message, start).map(Self::Soa) + Soa::parse_message_bytes(contents, start).map(Self::Soa) } RType::WKS => { - <&Wks>::parse_from_message(message, start).map(Self::Wks) + <&Wks>::parse_message_bytes(contents, start).map(Self::Wks) } RType::PTR => { - Ptr::parse_from_message(message, start).map(Self::Ptr) + Ptr::parse_message_bytes(contents, start).map(Self::Ptr) } RType::HINFO => { - HInfo::parse_from_message(message, start).map(Self::HInfo) + HInfo::parse_message_bytes(contents, start).map(Self::HInfo) + } + RType::MX => { + Mx::parse_message_bytes(contents, start).map(Self::Mx) } - RType::MX => Mx::parse_from_message(message, start).map(Self::Mx), RType::TXT => { - <&Txt>::parse_from_message(message, start).map(Self::Txt) + <&Txt>::parse_message_bytes(contents, start).map(Self::Txt) } RType::AAAA => { - <&Aaaa>::parse_from_message(message, start).map(Self::Aaaa) + <&Aaaa>::parse_message_bytes(contents, start).map(Self::Aaaa) } RType::OPT => { - <&Opt>::parse_from_message(message, start).map(Self::Opt) + <&Opt>::parse_message_bytes(contents, start).map(Self::Opt) } - _ => <&UnknownRecordData>::parse_from_message(message, start) + _ => <&UnknownRecordData>::parse_message_bytes(contents, start) .map(|data| Self::Unknown(rtype, data)), } } From e2729478571af3f96ac227fa31516758bbe6b4ed Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 23 Jan 2025 16:02:36 +0100 Subject: [PATCH 097/167] [new_base/build] Overhaul message building - 'BuilderContext' now tracks the last question / record in a message, allowing its building to be recovered in the future (WIP). - 'MessageBuilder' inlines the 'Builder' and just stores a mutable reference to a 'Message' for simplicity. - 'Builder' no longer gives access to the message header, and uses '&UnsafeCell<[u8]>' to represent the message contents. --- src/new_base/build/builder.rs | 309 ++++++++-------------------- src/new_base/build/context.rs | 135 +++++++++++++ src/new_base/build/message.rs | 290 ++++++++++++--------------- src/new_base/build/mod.rs | 51 +---- src/new_base/build/question.rs | 106 ++++++++++ src/new_base/build/record.rs | 310 +++++++++++++++++------------ src/new_base/wire/size_prefixed.rs | 6 +- 7 files changed, 647 insertions(+), 560 deletions(-) create mode 100644 src/new_base/build/context.rs create mode 100644 src/new_base/build/question.rs diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 16444349c..edcc9b543 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -1,18 +1,18 @@ //! A builder for DNS messages. use core::{ - marker::PhantomData, + cell::UnsafeCell, mem::ManuallyDrop, - ptr::{self, NonNull}, + ptr::{self}, + slice, }; use crate::new_base::{ name::RevName, - wire::{BuildBytes, ParseBytesByRef, TruncationError}, - Header, Message, + wire::{BuildBytes, TruncationError}, }; -use super::BuildCommitted; +use super::{BuildCommitted, BuilderContext}; //----------- Builder -------------------------------------------------------- @@ -60,144 +60,49 @@ use super::BuildCommitted; /// will be reverted on failure). For this, we have [`delegate()`]. /// /// [`delegate()`]: Self::delegate() -/// -/// For example: -/// -/// ``` -/// # use domain::new_base::build::{Builder, BuildResult, BuilderContext}; -/// -/// /// A build function with the conventional type signature. -/// fn foo(mut builder: Builder<'_>) -> BuildResult { -/// // Content added by the parent builder is considered committed. -/// assert_eq!(builder.committed(), b"hi! "); -/// -/// // Append some content to the builder. -/// builder.append_bytes(b"foo!")?; -/// -/// // Try appending a very long string, which can't fit. -/// builder.append_bytes(b"helloworldthisiswaytoobig")?; -/// -/// Ok(builder.commit()) -/// } -/// -/// // Construct a builder for a particular buffer. -/// let mut buffer = [0u8; 20]; -/// let mut context = BuilderContext::default(); -/// let mut builder = Builder::new(&mut buffer, &mut context); -/// -/// // Try appending some content to the builder. -/// builder.append_bytes(b"hi! ").unwrap(); -/// assert_eq!(builder.appended(), b"hi! "); -/// -/// // Try calling 'foo' -- note that it will fail. -/// // Note that we delegated the builder. -/// foo(builder.delegate()).unwrap_err(); -/// -/// // No partial content was written. -/// assert_eq!(builder.appended(), b"hi! "); -/// ``` pub struct Builder<'b> { - /// The message being built. + /// The contents of the built message. /// - /// The message is divided into four parts: + /// The buffer is divided into three parts: /// - /// - The message header (borrowed mutably by this type). /// - Committed message contents (borrowed *immutably* by this type). /// - Appended message contents (borrowed mutably by this type). /// - Uninitialized message contents (borrowed mutably by this type). - message: NonNull, - - _message: PhantomData<&'b mut Message>, + contents: &'b UnsafeCell<[u8]>, /// Context for building. context: &'b mut BuilderContext, - /// The commit point of this builder. + /// The start point of this builder. /// /// Message contents up to this point are committed and cannot be removed /// by this builder. Message contents following this (up to the size in /// the builder context) are appended but uncommitted. - commit: usize, + start: usize, } -/// # Initialization -/// -/// In order to begin building a DNS message: -/// -/// ``` -/// # use domain::new_base::build::{Builder, BuilderContext}; -/// -/// // Allocate a slice of 'u8's somewhere. -/// let mut buffer = [0u8; 20]; -/// -/// // Obtain a builder context. -/// // -/// // The value doesn't matter, it will be overwritten. -/// let mut context = BuilderContext::default(); -/// -/// // Construct the actual 'Builder'. -/// let builder = Builder::new(&mut buffer, &mut context); -/// -/// assert!(builder.committed().is_empty()); -/// assert!(builder.appended().is_empty()); -/// ``` impl<'b> Builder<'b> { - /// Create a [`Builder`] for a new, empty DNS message. - /// - /// The message header is left uninitialized. Use [`Self::header_mut()`] - /// to initialize it. The message contents are completely empty. - /// - /// The provided builder context will be overwritten with a default state. - /// - /// # Panics - /// - /// Panics if the buffer is less than 12 bytes long (which is the minimum - /// possible size for a DNS message). - pub fn new( - buffer: &'b mut [u8], - context: &'b mut BuilderContext, - ) -> Self { - let message = Message::parse_bytes_by_mut(buffer) - .expect("The buffure must be at least 12 bytes in size"); - context.size = 0; - - // SAFETY: 'message' and 'context' are now consistent. - unsafe { Self::from_raw_parts(message.into(), context, 0) } - } - /// Construct a [`Builder`] from raw parts. /// - /// The provided components must originate from [`into_raw_parts()`], and - /// none of the components can be modified since they were extracted. - /// - /// [`into_raw_parts()`]: Self::into_raw_parts() - /// - /// This method is useful when overcoming limitations in lifetimes or - /// borrow checking, or when a builder has to be constructed from another - /// with specific characteristics. - /// /// # Safety /// - /// The expression `from_raw_parts(message, context, commit)` is sound if + /// The expression `from_raw_parts(contents, context, start)` is sound if /// and only if all of the following conditions are satisfied: /// - /// - `message` is a valid reference for the lifetime `'b`. - /// - `message.header` is mutably borrowed for `'b`. - /// - `message.contents[..commit]` is immutably borrowed for `'b`. - /// - `message.contents[commit..]` is mutably borrowed for `'b`. + /// - `message[..start]` is immutably borrowed for `'b`. + /// - `message[start..]` is mutably borrowed for `'b`. /// /// - `message` and `context` originate from the same builder. - /// - `commit <= context.size() <= message.contents.len()`. + /// - `start <= context.size() <= message.len()`. pub unsafe fn from_raw_parts( - message: NonNull, + contents: &'b UnsafeCell<[u8]>, context: &'b mut BuilderContext, - commit: usize, + start: usize, ) -> Self { Self { - message, - _message: PhantomData, + contents, context, - commit, + start, } } } @@ -210,38 +115,30 @@ impl<'b> Builder<'b> { /// ```text /// name | position /// --------------+--------- -/// header | -/// committed | 0 .. commit -/// appended | commit .. size -/// uninitialized | size .. limit +/// committed | 0 .. start +/// appended | start .. offset +/// uninitialized | offset .. limit /// inaccessible | limit .. /// ``` /// -/// The DNS message header can be modified at any time. It is made available -/// through [`header()`] and [`header_mut()`]. In general, it is inadvisable -/// to change the section counts arbitrarily (although it will not cause -/// undefined behaviour). -/// -/// [`header()`]: Self::header() -/// [`header_mut()`]: Self::header_mut() -/// /// The committed content of the builder is immutable, and is available to /// reference, through [`committed()`], for the lifetime `'b`. /// /// [`committed()`]: Self::committed() /// -/// The appended content of the builder is made available via [`appended()`]. -/// It is content that has been added by this builder, but that has not yet -/// been committed. When the [`Builder`] is dropped, this content is removed -/// (it becomes uninitialized). Appended content can be modified, but any -/// compressed names within it have to be handled with great care; they can -/// only be modified by removing them entirely (by rewinding the builder, -/// using [`rewind()`]) and building them again. When compressed names are -/// guaranteed to not be modified, [`appended_mut()`] can be used. +/// The appended but uncommitted content of the builder is made available via +/// [`uncommitted_mut()`]. It is content that has been added by this builder, +/// but that has not yet been committed. When the [`Builder`] is dropped, +/// this content is removed (it becomes uninitialized). Appended content can +/// be modified, but any compressed names within it have to be handled with +/// great care; they can only be modified by removing them entirely (by +/// rewinding the builder, using [`rewind()`]) and building them again. When +/// compressed names are guaranteed to not be modified, [`uncommitted_mut()`] +/// can be used. /// /// [`appended()`]: Self::appended() /// [`rewind()`]: Self::rewind() -/// [`appended_mut()`]: Self::appended_mut() +/// [`uncommitted_mut()`]: Self::uncommitted_mut() /// /// The uninitialized space in the builder will be written to when appending /// new content. It can be accessed directly, in case that is more efficient @@ -259,37 +156,31 @@ impl<'b> Builder<'b> { /// /// [`limit_to()`]: Self::limit_to() impl<'b> Builder<'b> { - /// The header of the DNS message. - pub fn header(&self) -> &Header { - // SAFETY: 'message.header' is mutably borrowed by 'self'. - unsafe { &(*self.message.as_ptr()).header } - } - - /// The header of the DNS message, mutably. - /// - /// It is possible to modify the section counts arbitrarily through this - /// method; while doing so cannot cause undefined behaviour, it is not - /// recommended. - pub fn header_mut(&mut self) -> &mut Header { - // SAFETY: 'message.header' is mutably borrowed by 'self'. - unsafe { &mut (*self.message.as_ptr()).header } - } - /// Committed message contents. pub fn committed(&self) -> &'b [u8] { - // SAFETY: 'message.contents[..commit]' is immutably borrowed by - // 'self'. - unsafe { &(*self.message.as_ptr()).contents[..self.commit] } + let message = self.contents.get().cast_const().cast(); + // SAFETY: 'message[..start]' is immutably borrowed. + unsafe { slice::from_raw_parts(message, self.start) } + } + + /// Appended (and committed) message contents. + pub fn appended(&self) -> &[u8] { + let message = self.contents.get().cast_const().cast(); + // SAFETY: 'message[..offset]' is (im)mutably borrowed. + unsafe { slice::from_raw_parts(message, self.context.size) } } /// The appended but uncommitted contents of the message. /// /// The builder can modify or rewind these contents, so they are offered /// with a short lifetime. - pub fn appended(&self) -> &[u8] { - // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. - let range = self.commit..self.context.size; - unsafe { &(*self.message.as_ptr()).contents[range] } + pub fn uncommitted(&self) -> &[u8] { + let message = self.contents.get().cast::().cast_const(); + // SAFETY: It is guaranteed that 'start <= message.len()'. + let message = unsafe { message.offset(self.start as isize) }; + let size = self.context.size - self.start; + // SAFETY: 'message[start..]' is mutably borrowed. + unsafe { slice::from_raw_parts(message, size) } } /// The appended but uncommitted contents of the message, mutably. @@ -298,10 +189,13 @@ impl<'b> Builder<'b> { /// /// The caller must not modify any compressed names among these bytes. /// This can invalidate name compression state. - pub unsafe fn appended_mut(&mut self) -> &mut [u8] { - // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. - let range = self.commit..self.context.size; - unsafe { &mut (*self.message.as_ptr()).contents[range] } + pub unsafe fn uncommitted_mut(&mut self) -> &mut [u8] { + let message = self.contents.get().cast::(); + // SAFETY: It is guaranteed that 'start <= message.len()'. + let message = unsafe { message.offset(self.start as isize) }; + let size = self.context.size - self.start; + // SAFETY: 'message[start..]' is mutably borrowed. + unsafe { slice::from_raw_parts_mut(message, size) } } /// Uninitialized space in the message buffer. @@ -310,39 +204,12 @@ impl<'b> Builder<'b> { /// should be treated as appended content in the message, call /// [`self.mark_appended(n)`](Self::mark_appended()). pub fn uninitialized(&mut self) -> &mut [u8] { - // SAFETY: 'message.contents[commit..]' is mutably borrowed by 'self'. - unsafe { &mut (*self.message.as_ptr()).contents[self.context.size..] } - } - - /// The message with all committed contents. - /// - /// The header of the message can be modified by the builder, so the - /// returned reference has a short lifetime. The message contents can be - /// borrowed for a longer lifetime -- see [`committed()`]. The message - /// does not include content that has been appended but not committed. - /// - /// [`committed()`]: Self::committed() - pub fn message(&self) -> &Message { - // SAFETY: All of 'message' can be immutably borrowed by 'self'. - unsafe { self.message.as_ref() }.slice_to(self.commit) - } - - /// The message including any uncommitted contents. - pub fn cur_message(&self) -> &Message { - // SAFETY: All of 'message' can be immutably borrowed by 'self'. - unsafe { self.message.as_ref() }.slice_to(self.context.size) - } - - /// A pointer to the message, including any uncommitted contents. - /// - /// The first `commit` bytes of the message contents (also provided by - /// [`Self::committed()`]) are immutably borrowed for the lifetime `'b`. - /// The remainder of the message is initialized and borrowed by `self`. - pub fn cur_message_ptr(&mut self) -> NonNull { - let message = self.message.as_ptr(); - let size = self.context.size; - let message = unsafe { Message::ptr_slice_to(message, size) }; - unsafe { NonNull::new_unchecked(message) } + let message = self.contents.get().cast::(); + // SAFETY: It is guaranteed that 'size <= message.len()'. + let message = unsafe { message.offset(self.context.size as isize) }; + let size = self.max_size() - self.context.size; + // SAFETY: 'message[size..]' is mutably borrowed. + unsafe { slice::from_raw_parts_mut(message, size) } } /// The builder context. @@ -356,7 +223,15 @@ impl<'b> Builder<'b> { /// initialized. The content before this point has been committed and is /// immutable. The builder can be rewound up to this point. pub fn start(&self) -> usize { - self.commit + self.start + } + + /// The append point of this builder. + /// + /// This is the offset into the message contents at which new data will be + /// written. The content after this point is uninitialized. + pub fn offset(&self) -> usize { + self.context.size } /// The size limit of this builder. @@ -365,12 +240,11 @@ impl<'b> Builder<'b> { /// [`TruncationError`]s will occur. The limit can be tightened using /// [`limit_to()`](Self::limit_to()). pub fn max_size(&self) -> usize { - // SAFETY: 'Message' ends with a slice DST, and so references to it - // hold the length of that slice; we can cast it to another slice type - // and the pointer representation is unchanged. By using a slice type - // of ZST elements, aliasing is impossible, and it can be dereferenced + // SAFETY: We can cast 'contents' to another slice type and the + // pointer representation is unchanged. By using a slice type of ZST + // elements, aliasing is impossible, and it can be dereferenced // safely. - unsafe { &*(self.message.as_ptr() as *mut [()]) }.len() + unsafe { &*(self.contents.get() as *mut [()]) }.len() } /// Decompose this builder into raw parts. @@ -389,14 +263,14 @@ impl<'b> Builder<'b> { /// The builder can be recomposed with [`Self::from_raw_parts()`]. pub fn into_raw_parts( self, - ) -> (NonNull, &'b mut BuilderContext, usize) { + ) -> (&'b UnsafeCell<[u8]>, &'b mut BuilderContext, usize) { // NOTE: The context has to be moved out carefully. - let (message, commit) = (self.message, self.commit); + let (contents, start) = (self.contents, self.start); let this = ManuallyDrop::new(self); let this = (&*this) as *const Self; // SAFETY: 'this' is a valid object that can be moved out of. let context = unsafe { ptr::read(ptr::addr_of!((*this).context)) }; - (message, context, commit) + (contents, context, start) } } @@ -437,7 +311,7 @@ impl<'b> Builder<'b> { impl Builder<'_> { /// Rewind the builder, removing all uncommitted content. pub fn rewind(&mut self) { - self.context.size = self.commit; + self.context.size = self.start; } /// Commit the changes made by this builder. @@ -447,7 +321,7 @@ impl Builder<'_> { /// this method on success paths. pub fn commit(mut self) -> BuildCommitted { // Update 'commit' so that the drop glue is a no-op. - self.commit = self.context.size; + self.start = self.context.size; BuildCommitted } @@ -459,9 +333,11 @@ impl Builder<'_> { /// limit, a [`TruncationError`] is returned. pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { if self.context.size <= size { - let message = self.message.as_ptr(); - let message = unsafe { Message::ptr_slice_to(message, size) }; - self.message = unsafe { NonNull::new_unchecked(message) }; + let message = self.contents.get().cast::(); + debug_assert!(size <= self.max_size()); + self.contents = unsafe { + &*(ptr::slice_from_raw_parts_mut(message, size) as *const _) + }; Ok(()) } else { Err(TruncationError) @@ -490,7 +366,7 @@ impl Builder<'_> { pub fn delegate(&mut self) -> Builder<'_> { let commit = self.context.size; unsafe { - Builder::from_raw_parts(self.message, &mut *self.context, commit) + Builder::from_raw_parts(self.contents, &mut *self.context, commit) } } @@ -550,18 +426,3 @@ unsafe impl Send for Builder<'_> {} // SAFETY: Only parts of the referenced message that are borrowed immutably // can be accessed through an immutable reference to `self`. unsafe impl Sync for Builder<'_> {} - -//----------- BuilderContext ------------------------------------------------- - -/// Context for building a DNS message. -/// -/// This type holds auxiliary information necessary for building DNS messages, -/// e.g. name compression state. To construct it, call [`default()`]. -/// -/// [`default()`]: Self::default() -#[derive(Clone, Debug, Default)] -pub struct BuilderContext { - // TODO: Name compression. - /// The current size of the message contents. - size: usize, -} diff --git a/src/new_base/build/context.rs b/src/new_base/build/context.rs new file mode 100644 index 000000000..2f7f43da1 --- /dev/null +++ b/src/new_base/build/context.rs @@ -0,0 +1,135 @@ +//! Context for building DNS messages. + +//----------- BuilderContext ------------------------------------------------- + +/// Context for building a DNS message. +/// +/// This type holds auxiliary information necessary for building DNS messages, +/// e.g. name compression state. To construct it, call [`default()`]. +/// +/// [`default()`]: Self::default() +#[derive(Clone, Debug, Default)] +pub struct BuilderContext { + // TODO: Name compression. + /// The current size of the message contents. + pub size: usize, + + /// The state of the DNS message. + pub state: MessageState, +} + +//----------- MessageState --------------------------------------------------- + +/// The state of a DNS message being built. +/// +/// A DNS message consists of a header, questions, answers, authorities, and +/// additionals. [`MessageState`] remembers the start position of the last +/// question or record in the message, allowing it to be modifying or removed +/// (for additional flexibility in the building process). +#[derive(Clone, Debug, Default)] +pub enum MessageState { + /// Questions are being built. + /// + /// The message already contains zero or more DNS questions. If there is + /// a last DNS question, its start position is unknown, so it cannot be + /// modified or removed. + /// + /// This is the default state for an empty message. + #[default] + Questions, + + /// A question is being built. + /// + /// The message contains one or more DNS questions. The last question can + /// be modified or truncated. + MidQuestion { + /// The offset of the question name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + }, + + /// Answer records are being built. + /// + /// The message already contains zero or more DNS answer records. If + /// there is a last DNS record, its start position is unknown, so it + /// cannot be modified or removed. + Answers, + + /// An answer record is being built. + /// + /// The message contains one or more DNS answer records. The last record + /// can be modified or truncated. + MidAnswer { + /// The offset of the record name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + + /// The offset of the record data. + /// + /// The offset is measured from the start of the message contents. + data: u16, + }, + + /// Authority records are being built. + /// + /// The message already contains zero or more DNS authority records. If + /// there is a last DNS record, its start position is unknown, so it + /// cannot be modified or removed. + Authorities, + + /// An authority record is being built. + /// + /// The message contains one or more DNS authority records. The last + /// record can be modified or truncated. + MidAuthority { + /// The offset of the record name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + + /// The offset of the record data. + /// + /// The offset is measured from the start of the message contents. + data: u16, + }, + + /// Additional records are being built. + /// + /// The message already contains zero or more DNS additional records. If + /// there is a last DNS record, its start position is unknown, so it + /// cannot be modified or removed. + Additionals, + + /// An additional record is being built. + /// + /// The message contains one or more DNS additional records. The last + /// record can be modified or truncated. + MidAdditional { + /// The offset of the record name. + /// + /// The offset is measured from the start of the message contents. + name: u16, + + /// The offset of the record data. + /// + /// The offset is measured from the start of the message contents. + data: u16, + }, +} + +impl MessageState { + /// The current section index. + /// + /// Questions, answers, authorities, and additionals are mapped to 0, 1, + /// 2, and 3, respectively. + pub const fn section_index(&self) -> u8 { + match self { + Self::Questions | Self::MidQuestion { .. } => 0, + Self::Answers | Self::MidAnswer { .. } => 1, + Self::Authorities | Self::MidAuthority { .. } => 2, + Self::Additionals | Self::MidAdditional { .. } => 3, + } + } +} diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index a79928726..35ec3d145 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -1,45 +1,35 @@ //! Building whole DNS messages. +use core::cell::UnsafeCell; + use crate::new_base::{ - wire::TruncationError, Header, Message, Question, RClass, RType, Record, - TTL, + wire::{ParseBytesByRef, TruncationError}, + Header, Message, Question, RClass, RType, Record, TTL, }; -use super::{BuildIntoMessage, Builder, BuilderContext, RecordBuilder}; +use super::{ + BuildIntoMessage, Builder, BuilderContext, MessageState, QuestionBuilder, + RecordBuilder, +}; //----------- MessageBuilder ------------------------------------------------- /// A builder for a whole DNS message. /// -/// This is subtly different from a regular [`Builder`] -- it does not allow -/// for commits and so can always modify the entire message. It has methods -/// for adding entire questions and records to the message. +/// This is a high-level building interface, offering methods to put together +/// entire questions and records. It directly writes into an allocated buffer +/// (on the stack or the heap). pub struct MessageBuilder<'b> { - /// The underlying [`Builder`]. - /// - /// Its commit point is always 0. - inner: Builder<'b>, + /// The message being constructed. + message: &'b mut Message, + + /// Context for building. + pub(super) context: &'b mut BuilderContext, } //--- Initialization impl<'b> MessageBuilder<'b> { - /// Construct a [`MessageBuilder`] from raw parts. - /// - /// # Safety - /// - /// - `message` and `context` are paired together. - pub unsafe fn from_raw_parts( - message: &'b mut Message, - context: &'b mut BuilderContext, - ) -> Self { - // SAFETY: since 'commit' is 0, no part of the message is immutably - // borrowed; it is thus sound to represent as a mutable borrow. - let inner = - unsafe { Builder::from_raw_parts(message.into(), context, 0) }; - Self { inner } - } - /// Initialize an empty [`MessageBuilder`]. /// /// The message header is left uninitialized. use [`Self::header_mut()`] @@ -53,8 +43,10 @@ impl<'b> MessageBuilder<'b> { buffer: &'b mut [u8], context: &'b mut BuilderContext, ) -> Self { - let inner = Builder::new(buffer, context); - Self { inner } + let message = Message::parse_bytes_by_mut(buffer) + .expect("The caller's buffer is at least 12 bytes big"); + *context = BuilderContext::default(); + Self { message, context } } } @@ -62,21 +54,18 @@ impl<'b> MessageBuilder<'b> { impl<'b> MessageBuilder<'b> { /// The message header. - /// - /// The header can be modified by the builder, and so is only available - /// for a short lifetime. Note that it implements [`Copy`]. pub fn header(&self) -> &Header { - self.inner.header() + &self.message.header } - /// Mutable access to the message header. + /// The message header, mutably. pub fn header_mut(&mut self) -> &mut Header { - self.inner.header_mut() + &mut self.message.header } /// The message built thus far. pub fn message(&self) -> &Message { - self.inner.cur_message() + self.message.slice_to(self.context.size) } /// The message built thus far, mutably. @@ -86,34 +75,26 @@ impl<'b> MessageBuilder<'b> { /// The caller must not modify any compressed names among these bytes. /// This can invalidate name compression state. pub unsafe fn message_mut(&mut self) -> &mut Message { - // SAFETY: Since no bytes are committed, and the rest of the message - // is borrowed mutably for 'self', we can use a mutable reference. - unsafe { self.inner.cur_message_ptr().as_mut() } + self.message.slice_to_mut(self.context.size) } /// The builder context. pub fn context(&self) -> &BuilderContext { - self.inner.context() - } - - /// Decompose this builder into raw parts. - /// - /// This returns the message buffer and the context for this builder. The - /// two are linked, and the builder can be recomposed with - /// [`Self::from_raw_parts()`]. - pub fn into_raw_parts(self) -> (&'b mut Message, &'b mut BuilderContext) { - let (mut message, context, _commit) = self.inner.into_raw_parts(); - // SAFETY: As per 'Builder::into_raw_parts()', the message is borrowed - // mutably for the lifetime 'b. Since the commit point is 0, there is - // no immutably-borrowed content in the message, so it can be turned - // into a regular reference. - (unsafe { message.as_mut() }, context) + self.context } } //--- Interaction impl MessageBuilder<'_> { + /// Reborrow the builder with a shorter lifetime. + pub fn reborrow(&mut self) -> MessageBuilder<'_> { + MessageBuilder { + message: self.message, + context: self.context, + } + } + /// Limit the total message size. /// /// The message will not be allowed to exceed the given size, in bytes. @@ -121,148 +102,124 @@ impl MessageBuilder<'_> { /// or TCP packet size is not considered. If the message already exceeds /// this size, a [`TruncationError`] is returned. /// - /// This size will apply to all builders for this message (including those - /// that delegated to `self`). It will not be automatically revoked if - /// message building fails. - /// /// # Panics /// /// Panics if the given size is less than 12 bytes. pub fn limit_to(&mut self, size: usize) -> Result<(), TruncationError> { - self.inner.limit_to(size) - } - - /// Append a question. - /// - /// # Panics - /// - /// Panics if the message contains any records (as questions must come - /// before all records). - pub fn append_question( + if 12 + self.context.size <= size { + // Move out of 'message' so that the full lifetime is available. + // See the 'replace_with' and 'take_mut' crates. + debug_assert!(size < 12 + self.message.contents.len()); + let message = unsafe { core::ptr::read(&self.message) }; + // NOTE: Precondition checked, will not panic. + let message = message.slice_to_mut(size - 12); + unsafe { core::ptr::write(&mut self.message, message) }; + Ok(()) + } else { + Err(TruncationError) + } + } + + /// Truncate the message. + /// + /// This will remove all message contents and mark it as truncated. + pub fn truncate(&mut self) { + self.message.header.flags = + self.message.header.flags.set_truncated(true); + *self.context = BuilderContext::default(); + } + + /// Obtain a [`Builder`]. + pub(super) fn builder(&mut self, start: usize) -> Builder<'_> { + debug_assert!(start <= self.context.size); + unsafe { + let contents = &mut self.message.contents; + let contents = contents as *mut [u8] as *const UnsafeCell<[u8]>; + Builder::from_raw_parts(&*contents, &mut self.context, start) + } + } + + /// Build a question. + /// + /// If a question is already being built, it will be finished first. If + /// an answer, authority, or additional record has been added, [`None`] is + /// returned instead. + pub fn build_question( &mut self, question: &Question, - ) -> Result<(), TruncationError> - where - N: BuildIntoMessage, - { - // Ensure there are no records present. - assert_eq!(self.header().counts.as_array()[1..], [0, 0, 0]); + ) -> Result>, TruncationError> { + if self.context.state.section_index() > 0 { + // We've progressed into a later section. + return Ok(None); + } - question.build_into_message(self.inner.delegate())?; - self.header_mut().counts.questions += 1; - Ok(()) + self.context.state = MessageState::Questions; + QuestionBuilder::build(self.reborrow(), question).map(Some) } - /// Build an arbitrary record. + /// Build an answer record. /// - /// The record will be added to the specified section (1, 2, or 3, i.e. - /// answers, authorities, and additional records respectively). There - /// must not be any existing records in sections after this one. - pub fn build_record( + /// If a question or answer is already being built, it will be finished + /// first. If an authority or additional record has been added, [`None`] + /// is returned instead. + pub fn build_answer( &mut self, rname: impl BuildIntoMessage, rtype: RType, rclass: RClass, ttl: TTL, - section: u8, - ) -> Result, TruncationError> { - RecordBuilder::new( - self.inner.delegate(), + ) -> Result>, TruncationError> { + if self.context.state.section_index() > 1 { + // We've progressed into a later section. + return Ok(None); + } + + let record = Record { rname, rtype, rclass, ttl, - section, - ) - } - - /// Append an answer record. - /// - /// # Panics - /// - /// Panics if the message contains any authority or additional records. - pub fn append_answer( - &mut self, - record: &Record, - ) -> Result<(), TruncationError> - where - N: BuildIntoMessage, - D: BuildIntoMessage, - { - // Ensure there are no authority or additional records present. - assert_eq!(self.header().counts.as_array()[2..], [0, 0]); + rdata: &[] as &[u8], + }; - record.build_into_message(self.inner.delegate())?; - self.header_mut().counts.answers += 1; - Ok(()) - } - - /// Build an answer record. - /// - /// # Panics - /// - /// Panics if the message contains any authority or additional records. - pub fn build_answer( - &mut self, - rname: impl BuildIntoMessage, - rtype: RType, - rclass: RClass, - ttl: TTL, - ) -> Result, TruncationError> { - self.build_record(rname, rtype, rclass, ttl, 1) - } - - /// Append an authority record. - /// - /// # Panics - /// - /// Panics if the message contains any additional records. - pub fn append_authority( - &mut self, - record: &Record, - ) -> Result<(), TruncationError> - where - N: BuildIntoMessage, - D: BuildIntoMessage, - { - // Ensure there are no additional records present. - assert_eq!(self.header().counts.as_array()[3..], [0]); - - record.build_into_message(self.inner.delegate())?; - self.header_mut().counts.authorities += 1; - Ok(()) + self.context.state = MessageState::Answers; + RecordBuilder::build(self.reborrow(), &record).map(Some) } /// Build an authority record. /// - /// # Panics - /// - /// Panics if the message contains any additional records. + /// If a question, answer, or authority is already being built, it will be + /// finished first. If an additional record has been added, [`None`] is + /// returned instead. pub fn build_authority( &mut self, rname: impl BuildIntoMessage, rtype: RType, rclass: RClass, ttl: TTL, - ) -> Result, TruncationError> { - self.build_record(rname, rtype, rclass, ttl, 2) - } + ) -> Result>, TruncationError> { + if self.context.state.section_index() > 2 { + // We've progressed into a later section. + return Ok(None); + } - /// Append an additional record. - pub fn append_additional( - &mut self, - record: &Record, - ) -> Result<(), TruncationError> - where - N: BuildIntoMessage, - D: BuildIntoMessage, - { - record.build_into_message(self.inner.delegate())?; - self.header_mut().counts.additional += 1; - Ok(()) + let record = Record { + rname, + rtype, + rclass, + ttl, + rdata: &[] as &[u8], + }; + + self.context.state = MessageState::Authorities; + RecordBuilder::build(self.reborrow(), &record).map(Some) } /// Build an additional record. + /// + /// If a question or record is already being built, it will be finished + /// first. Note that it is always possible to add an additional record to + /// a message. pub fn build_additional( &mut self, rname: impl BuildIntoMessage, @@ -270,6 +227,15 @@ impl MessageBuilder<'_> { rclass: RClass, ttl: TTL, ) -> Result, TruncationError> { - self.build_record(rname, rtype, rclass, ttl, 3) + let record = Record { + rname, + rtype, + rclass, + ttl, + rdata: &[] as &[u8], + }; + + self.context.state = MessageState::Additionals; + RecordBuilder::build(self.reborrow(), &record) } } diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index e723ef521..ba62062b6 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -4,8 +4,6 @@ //! but it is not specialized to DNS messages. This module provides that //! specialization within an ergonomic interface. //! -//! # The High-Level Interface -//! //! The core of the high-level interface is [`MessageBuilder`]. It provides //! the most intuitive methods for appending whole questions and records. //! @@ -42,58 +40,25 @@ //! qtype: QType::A, //! qclass: QClass::IN, //! }; -//! builder.append_question(&question).unwrap(); +//! let _ = builder.build_question(&question).unwrap().unwrap(); //! //! // Use the built message. //! let message = builder.message(); //! # let _ = message; //! ``` -//! -//! # The Low-Level Interface -//! -//! [`Builder`] is a powerful low-level interface that can be used to build -//! DNS messages. It implements atomic building and name compression, and is -//! the foundation of [`MessageBuilder`]. -//! -//! The [`Builder`] interface does not know about questions and records; it is -//! only capable of appending simple bytes and compressing domain names. Its -//! access to the message buffer is limited; it can only append, modify, or -//! truncate the message up to a certain point (all data before that point is -//! immutable). Special attention is given to the message header, as it can -//! be modified at any point in the message building process. -//! -//! ``` -//! use domain::new_base::build::{BuilderContext, Builder, BuildIntoMessage}; -//! use domain::new_rdata::A; -//! -//! // Construct a builder for a particular buffer. -//! let mut buffer = [0u8; 20]; -//! let mut context = BuilderContext::default(); -//! let mut builder = Builder::new(&mut buffer, &mut context); -//! -//! // Try appending some raw bytes to the builder. -//! builder.append_bytes(b"hi! ").unwrap(); -//! assert_eq!(builder.appended(), b"hi! "); -//! -//! // Try appending some structured content to the builder. -//! A::from(std::net::Ipv4Addr::new(127, 0, 0, 1)) -//! .build_into_message(builder.delegate()) -//! .unwrap(); -//! assert_eq!(builder.appended(), b"hi! \x7F\x00\x00\x01"); -//! -//! // Finish using the builder. -//! builder.commit(); -//! -//! // Note: the first 12 bytes hold the message header. -//! assert_eq!(&buffer[12..20], b"hi! \x7F\x00\x00\x01"); -//! ``` mod builder; -pub use builder::{Builder, BuilderContext}; +pub use builder::Builder; + +mod context; +pub use context::{BuilderContext, MessageState}; mod message; pub use message::MessageBuilder; +mod question; +pub use question::QuestionBuilder; + mod record; pub use record::RecordBuilder; diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs new file mode 100644 index 000000000..7c8c8b1e8 --- /dev/null +++ b/src/new_base/build/question.rs @@ -0,0 +1,106 @@ +//! Building DNS questions. + +use crate::new_base::{ + name::UnparsedName, + parse::ParseMessageBytes, + wire::{ParseBytes, TruncationError}, + QClass, QType, Question, +}; + +use super::{BuildCommitted, BuildIntoMessage, MessageBuilder, MessageState}; + +//----------- QuestionBuilder ------------------------------------------------ + +/// A DNS question builder. +pub struct QuestionBuilder<'b> { + /// The underlying message builder. + builder: MessageBuilder<'b>, + + /// The offset of the question name. + name: u16, +} + +//--- Construction + +impl<'b> QuestionBuilder<'b> { + /// Build a [`Question`]. + /// + /// The provided builder must be empty (i.e. must not have uncommitted + /// content). + pub(super) fn build( + mut builder: MessageBuilder<'b>, + question: &Question, + ) -> Result { + // TODO: Require that the QNAME serialize correctly? + let start = builder.context.size; + question.build_into_message(builder.builder(start))?; + let name = start.try_into().expect("Messages are at most 64KiB"); + builder.context.state = MessageState::MidQuestion { name }; + Ok(Self { builder, name }) + } + + /// Reconstruct a [`QuestionBuilder`] from raw parts. + /// + /// # Safety + /// + /// `builder.message().contents[name..]` must represent a valid + /// [`Question`] in the wire format. + pub unsafe fn from_raw_parts( + builder: MessageBuilder<'b>, + name: u16, + ) -> Self { + Self { builder, name } + } +} + +//--- Inspection + +impl<'b> QuestionBuilder<'b> { + /// The (unparsed) question name. + pub fn qname(&self) -> &UnparsedName { + let contents = &self.builder.message().contents; + let contents = &contents[..contents.len() - 4]; + <&UnparsedName>::parse_message_bytes(contents, self.name.into()) + .expect("The question was serialized correctly") + } + + /// The question type. + pub fn qtype(&self) -> QType { + let contents = &self.builder.message().contents; + QType::parse_bytes(&contents[contents.len() - 4..contents.len() - 2]) + .expect("The question was serialized correctly") + } + + /// The question class. + pub fn qclass(&self) -> QClass { + let contents = &self.builder.message().contents; + QClass::parse_bytes(&contents[contents.len() - 2..]) + .expect("The question was serialized correctly") + } + + /// Deconstruct this [`QuestionBuilder`] into its raw parts. + pub fn into_raw_parts(self) -> (MessageBuilder<'b>, u16) { + (self.builder, self.name) + } +} + +//--- Interaction + +impl<'b> QuestionBuilder<'b> { + /// Commit this question. + /// + /// The builder will be consumed, and the question will be committed so + /// that it can no longer be removed. + pub fn commit(self) -> BuildCommitted { + self.builder.context.state = MessageState::Questions; + BuildCommitted + } + + /// Stop building and remove this question. + /// + /// The builder will be consumed, and the question will be removed. + pub fn cancel(self) { + self.builder.context.size = self.name.into(); + self.builder.context.state = MessageState::Questions; + } +} diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index aac0857c3..e1ef789ab 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -1,176 +1,230 @@ //! Building DNS records. +use core::{mem::ManuallyDrop, ptr}; + use crate::new_base::{ - name::RevName, - wire::{AsBytes, TruncationError}, - Header, Message, RClass, RType, TTL, + name::UnparsedName, + parse::ParseMessageBytes, + wire::{AsBytes, ParseBytes, SizePrefixed, TruncationError}, + RClass, RType, Record, TTL, }; -use super::{BuildCommitted, BuildIntoMessage, Builder}; +use super::{ + BuildCommitted, BuildIntoMessage, Builder, MessageBuilder, MessageState, +}; -//----------- RecordBuilder -------------------------------------------------- +//----------- RecordBuilder ------------------------------------------------ -/// A builder for a DNS record. -/// -/// This is used to incrementally build the data for a DNS record. It can be -/// constructed using [`MessageBuilder::build_answer()`] etc. -/// -/// [`MessageBuilder::build_answer()`]: super::MessageBuilder::build_answer() +/// A DNS record builder. pub struct RecordBuilder<'b> { - /// The underlying [`Builder`]. - /// - /// Its commit point lies at the beginning of the record. - inner: Builder<'b>, + /// The underlying message builder. + builder: MessageBuilder<'b>, - /// The position of the record data. - /// - /// This is an offset from the message contents. - start: usize, + /// The offset of the record name. + name: u16, - /// The section the record is a part of. - /// - /// The appropriate section count will be incremented on completion. - section: u8, + /// The offset of the record data. + data: u16, } -//--- Initialization +//--- Construction impl<'b> RecordBuilder<'b> { - /// Construct a [`RecordBuilder`] from raw parts. + /// Build a [`Record`]. + /// + /// The provided builder must be empty (i.e. must not have uncommitted + /// content). + pub(super) fn build( + mut builder: MessageBuilder<'b>, + record: &Record, + ) -> Result + where + N: BuildIntoMessage, + D: BuildIntoMessage, + { + // Build the record and remember important positions. + let start = builder.context.size; + let (name, data) = { + let name = start.try_into().expect("Messages are at most 64KiB"); + let mut b = builder.builder(start); + record.rname.build_into_message(b.delegate())?; + b.append_bytes(&record.rtype.as_bytes())?; + b.append_bytes(&record.rclass.as_bytes())?; + b.append_bytes(&record.ttl.as_bytes())?; + let size = b.context().size; + SizePrefixed::new(&record.rdata) + .build_into_message(b.delegate())?; + let data = + (size + 2).try_into().expect("Messages are at most 64KiB"); + (name, data) + }; + + // Update the message state. + match builder.context.state { + ref mut state @ MessageState::Answers => { + *state = MessageState::MidAnswer { name, data }; + } + + ref mut state @ MessageState::Authorities => { + *state = MessageState::MidAuthority { name, data }; + } + + ref mut state @ MessageState::Additionals => { + *state = MessageState::MidAdditional { name, data }; + } + + _ => unreachable!(), + } + + Ok(Self { + builder, + name, + data, + }) + } + + /// Reconstruct a [`RecordBuilder`] from raw parts. /// /// # Safety /// - /// - `builder`, `start`, and `section` are paired together. + /// `builder.message().contents[name..]` must represent a valid + /// [`Record`] in the wire format. `contents[data..]` must represent the + /// record data (i.e. immediately after the record data size field). pub unsafe fn from_raw_parts( - builder: Builder<'b>, - start: usize, - section: u8, + builder: MessageBuilder<'b>, + name: u16, + data: u16, ) -> Self { Self { - inner: builder, - start, - section, + builder, + name, + data, } } - - /// Initialize a new [`RecordBuilder`]. - /// - /// A new record with the given name, type, and class will be created. - /// The returned builder can be used to add data for the record. - /// - /// The count for the specified section (1, 2, or 3, i.e. answers, - /// authorities, and additional records respectively) will be incremented - /// when the builder finishes successfully. - pub fn new( - mut builder: Builder<'b>, - rname: impl BuildIntoMessage, - rtype: RType, - rclass: RClass, - ttl: TTL, - section: u8, - ) -> Result { - debug_assert_eq!(builder.appended(), &[] as &[u8]); - debug_assert!((1..4).contains(§ion)); - - assert!(builder - .header() - .counts - .as_array() - .iter() - .skip(1 + section as usize) - .all(|&c| c == 0)); - - // Build the record header. - rname.build_into_message(builder.delegate())?; - builder.append_bytes(rtype.as_bytes())?; - builder.append_bytes(rclass.as_bytes())?; - builder.append_bytes(ttl.as_bytes())?; - builder.append_bytes(&0u16.to_be_bytes())?; - let start = builder.appended().len(); - - // Set up the builder. - Ok(Self { - inner: builder, - start, - section, - }) - } } //--- Inspection impl<'b> RecordBuilder<'b> { - /// The message header. - pub fn header(&self) -> &Header { - self.inner.header() + /// The (unparsed) record name. + pub fn rname(&self) -> &UnparsedName { + let contents = &self.builder.message().contents; + let contents = &contents[..contents.len() - 4]; + <&UnparsedName>::parse_message_bytes(contents, self.name.into()) + .expect("The record was serialized correctly") } - /// The message without this record. - pub fn message(&self) -> &Message { - self.inner.message() + /// The record type. + pub fn rtype(&self) -> RType { + let contents = &self.builder.message().contents; + let contents = &contents[usize::from(self.data) - 8..]; + RType::parse_bytes(&contents[0..2]) + .expect("The record was serialized correctly") } - /// The record data appended thus far. - pub fn data(&self) -> &[u8] { - &self.inner.appended()[self.start..] + /// The record class. + pub fn rclass(&self) -> RClass { + let contents = &self.builder.message().contents; + let contents = &contents[usize::from(self.data) - 8..]; + RClass::parse_bytes(&contents[2..4]) + .expect("The record was serialized correctly") } - /// Decompose this builder into raw parts. - /// - /// This returns the underlying builder, the offset of the record data in - /// the record, and the section number for this record (1, 2, or 3). The - /// builder can be recomposed with [`Self::from_raw_parts()`]. - pub fn into_raw_parts(self) -> (Builder<'b>, usize, u8) { - (self.inner, self.start, self.section) + /// The TTL. + pub fn ttl(&self) -> TTL { + let contents = &self.builder.message().contents; + let contents = &contents[usize::from(self.data) - 8..]; + TTL::parse_bytes(&contents[4..8]) + .expect("The record was serialized correctly") + } + + /// The record data built thus far. + pub fn rdata(&self) -> &[u8] { + &self.builder.message().contents[usize::from(self.data)..] + } + + /// Deconstruct this [`RecordBuilder`] into its raw parts. + pub fn into_raw_parts(self) -> (MessageBuilder<'b>, u16, u16) { + let (name, data) = (self.name, self.data); + let this = ManuallyDrop::new(self); + let this = (&*this) as *const Self; + // SAFETY: 'this' is a valid object that can be moved out of. + let builder = unsafe { ptr::read(ptr::addr_of!((*this).builder)) }; + (builder, name, data) } } //--- Interaction -impl RecordBuilder<'_> { - /// Finish the record. +impl<'b> RecordBuilder<'b> { + /// Commit this record. /// - /// The respective section count will be incremented. The builder will be - /// consumed and the record will be committed. - pub fn finish(mut self) -> BuildCommitted { - // Increment the appropriate section count. - self.inner.header_mut().counts.as_array_mut() - [self.section as usize] += 1; - - // Set the record data length. - let size = self.inner.appended().len() - self.start; - let size = u16::try_from(size) - .expect("Record data must be smaller than 64KiB"); - // SAFETY: The record data size is not part of a compressed name. - let appended = unsafe { self.inner.appended_mut() }; - appended[self.start - 2..self.start] - .copy_from_slice(&size.to_be_bytes()); - - self.inner.commit() + /// The builder will be consumed, and the record will be committed so that + /// it can no longer be removed. + pub fn commit(self) -> BuildCommitted { + match self.builder.context.state { + ref mut state @ MessageState::MidAnswer { .. } => { + *state = MessageState::Answers; + } + + ref mut state @ MessageState::MidAuthority { .. } => { + *state = MessageState::Authorities; + } + + ref mut state @ MessageState::MidAdditional { .. } => { + *state = MessageState::Additionals; + } + + _ => unreachable!(), + } + + // NOTE: The record data size will be fixed on drop. + BuildCommitted } - /// Delegate to a new builder. + /// Stop building and remove this record. /// - /// Any content committed by the builder will be added as record data. - pub fn delegate(&mut self) -> Builder<'_> { - self.inner.delegate() + /// The builder will be consumed, and the record will be removed. + pub fn cancel(self) { + self.builder.context.size = self.name.into(); + match self.builder.context.state { + ref mut state @ MessageState::MidAnswer { .. } => { + *state = MessageState::Answers; + } + + ref mut state @ MessageState::MidAuthority { .. } => { + *state = MessageState::Authorities; + } + + ref mut state @ MessageState::MidAdditional { .. } => { + *state = MessageState::Additionals; + } + + _ => unreachable!(), + } + + // NOTE: The drop glue is a no-op. } - /// Append some bytes. - /// - /// No name compression will be performed. - pub fn append_bytes( - &mut self, - bytes: &[u8], - ) -> Result<(), TruncationError> { - self.inner.append_bytes(bytes) + /// Delegate further building of the record data to a new [`Builder`]. + pub fn delegate(&mut self) -> Builder<'_> { + let offset = self.builder.context.size; + self.builder.builder(offset) } +} - /// Compress and append a domain name. - pub fn append_name( - &mut self, - name: &RevName, - ) -> Result<(), TruncationError> { - self.inner.append_name(name) +//--- Drop + +impl Drop for RecordBuilder<'_> { + fn drop(&mut self) { + // Fixup the record data size so the overall message builder is valid. + let size = self.builder.context.size as u16; + if self.data <= size { + // SAFETY: Only the record data size field is being modified. + let message = unsafe { self.builder.message_mut() }; + let data = usize::from(self.data); + message.contents[data - 2..data] + .copy_from_slice(&size.to_be_bytes()); + } } } diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index 5ac9effa9..9053c6431 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -248,13 +248,13 @@ impl BuildIntoMessage for SizePrefixed { &self, mut builder: build::Builder<'_>, ) -> BuildResult { - assert_eq!(builder.appended(), &[] as &[u8]); + assert_eq!(builder.uncommitted(), &[] as &[u8]); builder.append_bytes(&0u16.to_be_bytes())?; self.data.build_into_message(builder.delegate())?; - let size = builder.appended().len() - 2; + let size = builder.uncommitted().len() - 2; let size = u16::try_from(size).expect("the data never exceeds 64KiB"); // SAFETY: A 'U16' is being modified, not a domain name. - let size_buf = unsafe { &mut builder.appended_mut()[0..2] }; + let size_buf = unsafe { &mut builder.uncommitted_mut()[0..2] }; size_buf.copy_from_slice(&size.to_be_bytes()); Ok(builder.commit()) } From 41db42a76c9c8eff277ede96619bb37ab23269f8 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 12:05:44 +0100 Subject: [PATCH 098/167] [new_base/build/message] Add resumption methods --- src/new_base/build/context.rs | 2 +- src/new_base/build/message.rs | 75 ++++++++++++++++++++++++++++++++++ src/new_base/build/question.rs | 5 +++ src/new_base/build/record.rs | 5 +++ 4 files changed, 86 insertions(+), 1 deletion(-) diff --git a/src/new_base/build/context.rs b/src/new_base/build/context.rs index 2f7f43da1..bd423b4de 100644 --- a/src/new_base/build/context.rs +++ b/src/new_base/build/context.rs @@ -26,7 +26,7 @@ pub struct BuilderContext { /// additionals. [`MessageState`] remembers the start position of the last /// question or record in the message, allowing it to be modifying or removed /// (for additional flexibility in the building process). -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq, Eq)] pub enum MessageState { /// Questions are being built. /// diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 35ec3d145..288b4eefb 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -157,6 +157,24 @@ impl MessageBuilder<'_> { QuestionBuilder::build(self.reborrow(), question).map(Some) } + /// Resume building a question. + /// + /// If a question was built (using [`build_question()`]) but the returned + /// builder was neither committed nor canceled, the question builder will + /// be recovered and returned. + /// + /// [`build_question()`]: Self::build_question() + pub fn resume_question(&mut self) -> Option> { + let MessageState::MidQuestion { name } = self.context.state else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + QuestionBuilder::from_raw_parts(self.reborrow(), name) + }) + } + /// Build an answer record. /// /// If a question or answer is already being built, it will be finished @@ -186,6 +204,25 @@ impl MessageBuilder<'_> { RecordBuilder::build(self.reborrow(), &record).map(Some) } + /// Resume building an answer record. + /// + /// If an answer record was built (using [`build_answer()`]) but the + /// returned builder was neither committed nor canceled, the record + /// builder will be recovered and returned. + /// + /// [`build_answer()`]: Self::build_answer() + pub fn resume_answer(&mut self) -> Option> { + let MessageState::MidAnswer { name, data } = self.context.state + else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + RecordBuilder::from_raw_parts(self.reborrow(), name, data) + }) + } + /// Build an authority record. /// /// If a question, answer, or authority is already being built, it will be @@ -215,6 +252,25 @@ impl MessageBuilder<'_> { RecordBuilder::build(self.reborrow(), &record).map(Some) } + /// Resume building an authority record. + /// + /// If an authority record was built (using [`build_authority()`]) but + /// the returned builder was neither committed nor canceled, the record + /// builder will be recovered and returned. + /// + /// [`build_authority()`]: Self::build_authority() + pub fn resume_authority(&mut self) -> Option> { + let MessageState::MidAuthority { name, data } = self.context.state + else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + RecordBuilder::from_raw_parts(self.reborrow(), name, data) + }) + } + /// Build an additional record. /// /// If a question or record is already being built, it will be finished @@ -238,4 +294,23 @@ impl MessageBuilder<'_> { self.context.state = MessageState::Additionals; RecordBuilder::build(self.reborrow(), &record) } + + /// Resume building an additional record. + /// + /// If an additional record was built (using [`build_additional()`]) but + /// the returned builder was neither committed nor canceled, the record + /// builder will be recovered and returned. + /// + /// [`build_additional()`]: Self::build_additional() + pub fn resume_additional(&mut self) -> Option> { + let MessageState::MidAdditional { name, data } = self.context.state + else { + return None; + }; + + // SAFETY: 'self.context.state' is synchronized with the message. + Some(unsafe { + RecordBuilder::from_raw_parts(self.reborrow(), name, data) + }) + } } diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs index 7c8c8b1e8..d16921eff 100644 --- a/src/new_base/build/question.rs +++ b/src/new_base/build/question.rs @@ -12,6 +12,11 @@ use super::{BuildCommitted, BuildIntoMessage, MessageBuilder, MessageState}; //----------- QuestionBuilder ------------------------------------------------ /// A DNS question builder. +/// +/// A [`QuestionBuilder`] provides control over a DNS question that has been +/// appended to a message (using a [`MessageBuilder`]). It can be used to +/// inspect the question's fields, to replace it with a new question, and to +/// commit (finish building) or cancel (remove) the question. pub struct QuestionBuilder<'b> { /// The underlying message builder. builder: MessageBuilder<'b>, diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index e1ef789ab..66bf96a09 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -16,6 +16,11 @@ use super::{ //----------- RecordBuilder ------------------------------------------------ /// A DNS record builder. +/// +/// A [`RecordBuilder`] provides access to a record that has been appended to +/// a DNS message (using a [`MessageBuilder`]). It can be used to inspect the +/// record, to (re)write the record data, and to commit (finish building) or +/// cancel (remove) the record. pub struct RecordBuilder<'b> { /// The underlying message builder. builder: MessageBuilder<'b>, From 2fcd114c2f26f8cfc4eeb06c7300fca691615a9d Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 12:15:57 +0100 Subject: [PATCH 099/167] [new_base/build] Test 'MessageBuilder' and fix some bugs --- src/new_base/build/message.rs | 140 +++++++++++++++++++++++++++++++++ src/new_base/build/question.rs | 2 +- src/new_base/build/record.rs | 11 ++- src/new_base/record.rs | 32 ++++++++ 4 files changed, 180 insertions(+), 5 deletions(-) diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 288b4eefb..28967d448 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -314,3 +314,143 @@ impl MessageBuilder<'_> { }) } } + +//============ Tests ========================================================= + +#[cfg(test)] +mod test { + use crate::{ + new_base::{ + build::{BuildIntoMessage, BuilderContext, MessageState}, + name::RevName, + QClass, QType, Question, RClass, RType, TTL, + }, + new_rdata::A, + }; + + use super::MessageBuilder; + + const WWW_EXAMPLE_ORG: &RevName = unsafe { + RevName::from_bytes_unchecked(b"\x00\x03org\x07example\x03www") + }; + + #[test] + fn new() { + let mut buffer = [0u8; 12]; + let mut context = BuilderContext::default(); + + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + assert_eq!(&builder.message().contents, &[] as &[u8]); + assert_eq!(unsafe { &builder.message_mut().contents }, &[] as &[u8]); + assert_eq!(builder.context().size, 0); + assert_eq!(builder.context().state, MessageState::Questions); + } + + #[test] + fn build_question() { + let mut buffer = [0u8; 33]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + let question = Question { + qname: WWW_EXAMPLE_ORG, + qtype: QType::A, + qclass: QClass::IN, + }; + let qb = builder.build_question(&question).unwrap().unwrap(); + + assert_eq!(qb.qname().as_bytes(), b"\x03www\x07example\x03org\x00"); + assert_eq!(qb.qtype(), question.qtype); + assert_eq!(qb.qclass(), question.qclass); + + let state = MessageState::MidQuestion { name: 0 }; + assert_eq!(builder.context().state, state); + let contents = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01"; + assert_eq!(&builder.message().contents, contents); + } + + #[test] + fn resume_question() { + let mut buffer = [0u8; 33]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + let question = Question { + qname: WWW_EXAMPLE_ORG, + qtype: QType::A, + qclass: QClass::IN, + }; + let _ = builder.build_question(&question).unwrap().unwrap(); + + let qb = builder.resume_question().unwrap(); + assert_eq!(qb.qname().as_bytes(), b"\x03www\x07example\x03org\x00"); + assert_eq!(qb.qtype(), question.qtype); + assert_eq!(qb.qclass(), question.qclass); + } + + #[test] + fn build_record() { + let mut buffer = [0u8; 43]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + { + let mut rb = builder + .build_answer( + WWW_EXAMPLE_ORG, + RType::A, + RClass::IN, + TTL::from(42), + ) + .unwrap() + .unwrap(); + + assert_eq!( + rb.rname().as_bytes(), + b"\x03www\x07example\x03org\x00" + ); + assert_eq!(rb.rtype(), RType::A); + assert_eq!(rb.rclass(), RClass::IN); + assert_eq!(rb.ttl(), TTL::from(42)); + assert_eq!(rb.rdata(), b""); + + assert!(rb.delegate().append_bytes(&[0u8; 5]).is_err()); + + let rdata = A { + octets: [127, 0, 0, 1], + }; + rdata.build_into_message(rb.delegate()).unwrap(); + assert_eq!(rb.rdata(), b"\x7F\x00\x00\x01"); + } + + let state = MessageState::MidAnswer { name: 0, data: 27 }; + assert_eq!(builder.context().state, state); + let contents = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01\x00\x00\x00\x2A\x00\x04\x7F\x00\x00\x01"; + assert_eq!(&builder.message().contents, contents.as_slice()); + } + + #[test] + fn resume_record() { + let mut buffer = [0u8; 39]; + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + let _ = builder + .build_answer( + WWW_EXAMPLE_ORG, + RType::A, + RClass::IN, + TTL::from(42), + ) + .unwrap() + .unwrap(); + + let rb = builder.resume_answer().unwrap(); + assert_eq!(rb.rname().as_bytes(), b"\x03www\x07example\x03org\x00"); + assert_eq!(rb.rtype(), RType::A); + assert_eq!(rb.rclass(), RClass::IN); + assert_eq!(rb.ttl(), TTL::from(42)); + assert_eq!(rb.rdata(), b""); + } +} diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs index d16921eff..72ae6bac0 100644 --- a/src/new_base/build/question.rs +++ b/src/new_base/build/question.rs @@ -64,7 +64,7 @@ impl<'b> QuestionBuilder<'b> { /// The (unparsed) question name. pub fn qname(&self) -> &UnparsedName { let contents = &self.builder.message().contents; - let contents = &contents[..contents.len() - 4]; + let contents = &contents[usize::from(self.name)..contents.len() - 4]; <&UnparsedName>::parse_message_bytes(contents, self.name.into()) .expect("The question was serialized correctly") } diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index 66bf96a09..707539593 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -61,6 +61,7 @@ impl<'b> RecordBuilder<'b> { .build_into_message(b.delegate())?; let data = (size + 2).try_into().expect("Messages are at most 64KiB"); + b.commit(); (name, data) }; @@ -114,7 +115,8 @@ impl<'b> RecordBuilder<'b> { /// The (unparsed) record name. pub fn rname(&self) -> &UnparsedName { let contents = &self.builder.message().contents; - let contents = &contents[..contents.len() - 4]; + let contents = + &contents[usize::from(self.name)..usize::from(self.data) - 10]; <&UnparsedName>::parse_message_bytes(contents, self.name.into()) .expect("The record was serialized correctly") } @@ -122,7 +124,7 @@ impl<'b> RecordBuilder<'b> { /// The record type. pub fn rtype(&self) -> RType { let contents = &self.builder.message().contents; - let contents = &contents[usize::from(self.data) - 8..]; + let contents = &contents[usize::from(self.data) - 10..]; RType::parse_bytes(&contents[0..2]) .expect("The record was serialized correctly") } @@ -130,7 +132,7 @@ impl<'b> RecordBuilder<'b> { /// The record class. pub fn rclass(&self) -> RClass { let contents = &self.builder.message().contents; - let contents = &contents[usize::from(self.data) - 8..]; + let contents = &contents[usize::from(self.data) - 10..]; RClass::parse_bytes(&contents[2..4]) .expect("The record was serialized correctly") } @@ -138,7 +140,7 @@ impl<'b> RecordBuilder<'b> { /// The TTL. pub fn ttl(&self) -> TTL { let contents = &self.builder.message().contents; - let contents = &contents[usize::from(self.data) - 8..]; + let contents = &contents[usize::from(self.data) - 10..]; TTL::parse_bytes(&contents[4..8]) .expect("The record was serialized correctly") } @@ -228,6 +230,7 @@ impl Drop for RecordBuilder<'_> { // SAFETY: Only the record data size field is being modified. let message = unsafe { self.builder.message_mut() }; let data = usize::from(self.data); + let size = size - self.data; message.contents[data - 2..data] .copy_from_slice(&size.to_be_bytes()); } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 0c0e6fa3c..f3663df60 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -273,6 +273,22 @@ pub struct RClass { pub code: U16, } +//--- Associated Constants + +impl RClass { + const fn new(value: u16) -> Self { + Self { + code: U16::new(value), + } + } + + /// The Internet class. + pub const IN: Self = Self::new(1); + + /// The CHAOS class. + pub const CH: Self = Self::new(3); +} + //----------- TTL ------------------------------------------------------------ /// How long a record can be cached. @@ -298,6 +314,22 @@ pub struct TTL { pub value: U32, } +//--- Conversion to and from integers + +impl From for TTL { + fn from(value: u32) -> Self { + Self { + value: U32::new(value), + } + } +} + +impl From for u32 { + fn from(value: TTL) -> Self { + value.value.get() + } +} + //----------- ParseRecordData ------------------------------------------------ /// Parsing DNS record data. From 074486c8be67994ea67506419f2c8ac552c43628 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 12:40:39 +0100 Subject: [PATCH 100/167] [new_base/build] Track section counts while building --- src/new_base/build/context.rs | 47 ++++++++++ src/new_base/build/message.rs | 158 ++++++++++++++++++--------------- src/new_base/build/mod.rs | 19 +++- src/new_base/build/question.rs | 1 + src/new_base/build/record.rs | 35 ++------ 5 files changed, 155 insertions(+), 105 deletions(-) diff --git a/src/new_base/build/context.rs b/src/new_base/build/context.rs index bd423b4de..112330123 100644 --- a/src/new_base/build/context.rs +++ b/src/new_base/build/context.rs @@ -2,6 +2,8 @@ //----------- BuilderContext ------------------------------------------------- +use crate::new_base::SectionCounts; + /// Context for building a DNS message. /// /// This type holds auxiliary information necessary for building DNS messages, @@ -132,4 +134,49 @@ impl MessageState { Self::Additionals | Self::MidAdditional { .. } => 3, } } + + /// Whether a question or record is being built. + pub const fn mid_component(&self) -> bool { + match self { + Self::MidQuestion { .. } => true, + Self::MidAnswer { .. } => true, + Self::MidAuthority { .. } => true, + Self::MidAdditional { .. } => true, + _ => false, + } + } + + /// Commit a question or record and update the section counts. + pub fn commit(&mut self, counts: &mut SectionCounts) { + match self { + Self::MidQuestion { .. } => { + counts.questions += 1; + *self = Self::Questions; + } + Self::MidAnswer { .. } => { + counts.answers += 1; + *self = Self::Answers; + } + Self::MidAuthority { .. } => { + counts.authorities += 1; + *self = Self::Authorities; + } + Self::MidAdditional { .. } => { + counts.additional += 1; + *self = Self::Additionals; + } + _ => {} + } + } + + /// Cancel a question or record. + pub fn cancel(&mut self) { + match self { + Self::MidQuestion { .. } => *self = Self::Questions, + Self::MidAnswer { .. } => *self = Self::Answers, + Self::MidAuthority { .. } => *self = Self::Authorities, + Self::MidAdditional { .. } => *self = Self::Additionals, + _ => {} + } + } } diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 28967d448..3101ecd76 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -4,7 +4,7 @@ use core::cell::UnsafeCell; use crate::new_base::{ wire::{ParseBytesByRef, TruncationError}, - Header, Message, Question, RClass, RType, Record, TTL, + Header, Message, Question, Record, }; use super::{ @@ -21,7 +21,7 @@ use super::{ /// (on the stack or the heap). pub struct MessageBuilder<'b> { /// The message being constructed. - message: &'b mut Message, + pub(super) message: &'b mut Message, /// Context for building. pub(super) context: &'b mut BuilderContext, @@ -148,12 +148,18 @@ impl MessageBuilder<'_> { &mut self, question: &Question, ) -> Result>, TruncationError> { - if self.context.state.section_index() > 0 { + let state = &mut self.context.state; + if state.section_index() > 0 { // We've progressed into a later section. return Ok(None); } - self.context.state = MessageState::Questions; + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } + + *state = MessageState::Questions; QuestionBuilder::build(self.reborrow(), question).map(Some) } @@ -180,28 +186,23 @@ impl MessageBuilder<'_> { /// If a question or answer is already being built, it will be finished /// first. If an authority or additional record has been added, [`None`] /// is returned instead. - pub fn build_answer( + pub fn build_answer( &mut self, - rname: impl BuildIntoMessage, - rtype: RType, - rclass: RClass, - ttl: TTL, + record: &Record, ) -> Result>, TruncationError> { - if self.context.state.section_index() > 1 { + let state = &mut self.context.state; + if state.section_index() > 1 { // We've progressed into a later section. return Ok(None); } - let record = Record { - rname, - rtype, - rclass, - ttl, - rdata: &[] as &[u8], - }; + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } - self.context.state = MessageState::Answers; - RecordBuilder::build(self.reborrow(), &record).map(Some) + *state = MessageState::Answers; + RecordBuilder::build(self.reborrow(), record).map(Some) } /// Resume building an answer record. @@ -228,28 +229,23 @@ impl MessageBuilder<'_> { /// If a question, answer, or authority is already being built, it will be /// finished first. If an additional record has been added, [`None`] is /// returned instead. - pub fn build_authority( + pub fn build_authority( &mut self, - rname: impl BuildIntoMessage, - rtype: RType, - rclass: RClass, - ttl: TTL, + record: &Record, ) -> Result>, TruncationError> { - if self.context.state.section_index() > 2 { + let state = &mut self.context.state; + if state.section_index() > 2 { // We've progressed into a later section. return Ok(None); } - let record = Record { - rname, - rtype, - rclass, - ttl, - rdata: &[] as &[u8], - }; + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } - self.context.state = MessageState::Authorities; - RecordBuilder::build(self.reborrow(), &record).map(Some) + *state = MessageState::Authorities; + RecordBuilder::build(self.reborrow(), record).map(Some) } /// Resume building an authority record. @@ -276,23 +272,18 @@ impl MessageBuilder<'_> { /// If a question or record is already being built, it will be finished /// first. Note that it is always possible to add an additional record to /// a message. - pub fn build_additional( + pub fn build_additional( &mut self, - rname: impl BuildIntoMessage, - rtype: RType, - rclass: RClass, - ttl: TTL, + record: &Record, ) -> Result, TruncationError> { - let record = Record { - rname, - rtype, - rclass, - ttl, - rdata: &[] as &[u8], - }; + let state = &mut self.context.state; + if state.mid_component() { + let index = state.section_index() as usize; + self.message.header.counts.as_array_mut()[index] += 1; + } - self.context.state = MessageState::Additionals; - RecordBuilder::build(self.reborrow(), &record) + *state = MessageState::Additionals; + RecordBuilder::build(self.reborrow(), record) } /// Resume building an additional record. @@ -323,7 +314,9 @@ mod test { new_base::{ build::{BuildIntoMessage, BuilderContext, MessageState}, name::RevName, - QClass, QType, Question, RClass, RType, TTL, + wire::U16, + QClass, QType, Question, RClass, RType, Record, SectionCounts, + TTL, }, new_rdata::A, }; @@ -366,6 +359,7 @@ mod test { let state = MessageState::MidQuestion { name: 0 }; assert_eq!(builder.context().state, state); + assert_eq!(builder.message().header.counts, SectionCounts::default()); let contents = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01"; assert_eq!(&builder.message().contents, contents); } @@ -387,6 +381,15 @@ mod test { assert_eq!(qb.qname().as_bytes(), b"\x03www\x07example\x03org\x00"); assert_eq!(qb.qtype(), question.qtype); assert_eq!(qb.qclass(), question.qclass); + + qb.commit(); + assert_eq!( + builder.message().header.counts, + SectionCounts { + questions: U16::new(1), + ..Default::default() + } + ); } #[test] @@ -395,24 +398,24 @@ mod test { let mut context = BuilderContext::default(); let mut builder = MessageBuilder::new(&mut buffer, &mut context); + let record = Record { + rname: WWW_EXAMPLE_ORG, + rtype: RType::A, + rclass: RClass::IN, + ttl: TTL::from(42), + rdata: b"", + }; + { - let mut rb = builder - .build_answer( - WWW_EXAMPLE_ORG, - RType::A, - RClass::IN, - TTL::from(42), - ) - .unwrap() - .unwrap(); + let mut rb = builder.build_answer(&record).unwrap().unwrap(); assert_eq!( rb.rname().as_bytes(), b"\x03www\x07example\x03org\x00" ); - assert_eq!(rb.rtype(), RType::A); - assert_eq!(rb.rclass(), RClass::IN); - assert_eq!(rb.ttl(), TTL::from(42)); + assert_eq!(rb.rtype(), record.rtype); + assert_eq!(rb.rclass(), record.rclass); + assert_eq!(rb.ttl(), record.ttl); assert_eq!(rb.rdata(), b""); assert!(rb.delegate().append_bytes(&[0u8; 5]).is_err()); @@ -426,6 +429,7 @@ mod test { let state = MessageState::MidAnswer { name: 0, data: 27 }; assert_eq!(builder.context().state, state); + assert_eq!(builder.message().header.counts, SectionCounts::default()); let contents = b"\x03www\x07example\x03org\x00\x00\x01\x00\x01\x00\x00\x00\x2A\x00\x04\x7F\x00\x00\x01"; assert_eq!(&builder.message().contents, contents.as_slice()); } @@ -436,21 +440,29 @@ mod test { let mut context = BuilderContext::default(); let mut builder = MessageBuilder::new(&mut buffer, &mut context); - let _ = builder - .build_answer( - WWW_EXAMPLE_ORG, - RType::A, - RClass::IN, - TTL::from(42), - ) - .unwrap() - .unwrap(); + let record = Record { + rname: WWW_EXAMPLE_ORG, + rtype: RType::A, + rclass: RClass::IN, + ttl: TTL::from(42), + rdata: b"", + }; + let _ = builder.build_answer(&record).unwrap().unwrap(); let rb = builder.resume_answer().unwrap(); assert_eq!(rb.rname().as_bytes(), b"\x03www\x07example\x03org\x00"); - assert_eq!(rb.rtype(), RType::A); - assert_eq!(rb.rclass(), RClass::IN); - assert_eq!(rb.ttl(), TTL::from(42)); + assert_eq!(rb.rtype(), record.rtype); + assert_eq!(rb.rclass(), record.rclass); + assert_eq!(rb.ttl(), record.ttl); assert_eq!(rb.rdata(), b""); + + rb.commit(); + assert_eq!( + builder.message().header.counts, + SectionCounts { + answers: U16::new(1), + ..Default::default() + } + ); } } diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index ba62062b6..871d20571 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -81,13 +81,28 @@ impl BuildIntoMessage for &T { } } -impl BuildIntoMessage for [u8] { +impl BuildIntoMessage for u8 { fn build_into_message(&self, mut builder: Builder<'_>) -> BuildResult { - builder.append_bytes(self)?; + builder.append_bytes(&[*self])?; Ok(builder.commit()) } } +impl BuildIntoMessage for [T] { + fn build_into_message(&self, mut builder: Builder<'_>) -> BuildResult { + for elem in self { + elem.build_into_message(builder.delegate())?; + } + Ok(builder.commit()) + } +} + +impl BuildIntoMessage for [T; N] { + fn build_into_message(&self, builder: Builder<'_>) -> BuildResult { + self.as_slice().build_into_message(builder) + } +} + //----------- BuildResult ---------------------------------------------------- /// The result of building into a DNS message. diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs index 72ae6bac0..680b828fd 100644 --- a/src/new_base/build/question.rs +++ b/src/new_base/build/question.rs @@ -98,6 +98,7 @@ impl<'b> QuestionBuilder<'b> { /// that it can no longer be removed. pub fn commit(self) -> BuildCommitted { self.builder.context.state = MessageState::Questions; + self.builder.message.header.counts.questions += 1; BuildCommitted } diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index 707539593..6a27826c3 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -169,21 +169,10 @@ impl<'b> RecordBuilder<'b> { /// The builder will be consumed, and the record will be committed so that /// it can no longer be removed. pub fn commit(self) -> BuildCommitted { - match self.builder.context.state { - ref mut state @ MessageState::MidAnswer { .. } => { - *state = MessageState::Answers; - } - - ref mut state @ MessageState::MidAuthority { .. } => { - *state = MessageState::Authorities; - } - - ref mut state @ MessageState::MidAdditional { .. } => { - *state = MessageState::Additionals; - } - - _ => unreachable!(), - } + self.builder + .context + .state + .commit(&mut self.builder.message.header.counts); // NOTE: The record data size will be fixed on drop. BuildCommitted @@ -194,21 +183,7 @@ impl<'b> RecordBuilder<'b> { /// The builder will be consumed, and the record will be removed. pub fn cancel(self) { self.builder.context.size = self.name.into(); - match self.builder.context.state { - ref mut state @ MessageState::MidAnswer { .. } => { - *state = MessageState::Answers; - } - - ref mut state @ MessageState::MidAuthority { .. } => { - *state = MessageState::Authorities; - } - - ref mut state @ MessageState::MidAdditional { .. } => { - *state = MessageState::Additionals; - } - - _ => unreachable!(), - } + self.builder.context.state.cancel(); // NOTE: The drop glue is a no-op. } From 9a5e19313bbea2a69d7124573affdff497742184 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 12:40:52 +0100 Subject: [PATCH 101/167] [new_base/wire/parse] Fix miscount of doc test --- src/new_base/wire/parse.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/new_base/wire/parse.rs b/src/new_base/wire/parse.rs index 2a8da6642..2e517c197 100644 --- a/src/new_base/wire/parse.rs +++ b/src/new_base/wire/parse.rs @@ -189,7 +189,7 @@ pub unsafe trait ParseBytesByRef { /// may be provided. Until then, it should be implemented using one of /// the following expressions: /// - /// ```ignore + /// ```text /// fn ptr_with_address( /// &self, /// addr: *const (), From 1a54da1e46d628638d79d22c655a87514be34435 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 13:13:18 +0100 Subject: [PATCH 102/167] [new_base/wire/parse] Support parsing into arrays --- src/new_base/wire/parse.rs | 54 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/new_base/wire/parse.rs b/src/new_base/wire/parse.rs index 2e517c197..5ab9bedd5 100644 --- a/src/new_base/wire/parse.rs +++ b/src/new_base/wire/parse.rs @@ -1,6 +1,7 @@ //! Parsing bytes in the basic network format. use core::fmt; +use core::mem::MaybeUninit; //----------- ParseBytes ----------------------------------------------------- @@ -29,6 +30,15 @@ impl<'a, T: ?Sized + ParseBytesByRef> ParseBytes<'a> for &'a T { } } +impl<'a, T: SplitBytes<'a>, const N: usize> ParseBytes<'a> for [T; N] { + fn parse_bytes(bytes: &'a [u8]) -> Result { + match Self::split_bytes(bytes) { + Ok((this, &[])) => Ok(this), + _ => Err(ParseError), + } + } +} + /// Deriving [`ParseBytes`] automatically. /// /// [`ParseBytes`] can be derived on `struct`s (not `enum`s or `union`s). All @@ -86,6 +96,50 @@ impl<'a> SplitBytes<'a> for u8 { } } +impl<'a, T: SplitBytes<'a>, const N: usize> SplitBytes<'a> for [T; N] { + fn split_bytes( + mut bytes: &'a [u8], + ) -> Result<(Self, &'a [u8]), ParseError> { + // TODO: Rewrite when either 'array_try_map' or 'try_array_from_fn' + // is stabilized. + + /// A guard for dropping initialized elements on panic / failure. + struct Guard { + buffer: [MaybeUninit; N], + initialized: usize, + } + + impl Drop for Guard { + fn drop(&mut self) { + for elem in &mut self.buffer[..self.initialized] { + // SAFETY: The first 'initialized' elems are initialized. + unsafe { elem.assume_init_drop() }; + } + } + } + + let mut guard = Guard:: { + buffer: [const { MaybeUninit::uninit() }; N], + initialized: 0, + }; + + while guard.initialized < N { + let (elem, rest) = T::split_bytes(bytes)?; + guard.buffer[guard.initialized].write(elem); + bytes = rest; + guard.initialized += 1; + } + + // Disable the guard since we're moving data out now. + guard.initialized = 0; + + // SAFETY: '[MaybeUninit; N]' and '[T; N]' have the same layout, + // because 'MaybeUninit' and 'T' have the same layout, because it + // is documented in the standard library. + Ok((unsafe { core::mem::transmute_copy(&guard.buffer) }, bytes)) + } +} + /// Deriving [`SplitBytes`] automatically. /// /// [`SplitBytes`] can be derived on `struct`s (not `enum`s or `union`s). All From 00aaaf77a8e07a42d7bd84bf39474de71c0cd510 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 13:16:40 +0100 Subject: [PATCH 103/167] [new_base/build] Accept clippy suggestions --- src/new_base/build/builder.rs | 6 +++--- src/new_base/build/context.rs | 14 +++++++------- src/new_base/build/message.rs | 4 ++-- src/new_base/build/question.rs | 2 +- src/new_base/build/record.rs | 8 ++++---- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index edcc9b543..34c415ce3 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -177,7 +177,7 @@ impl<'b> Builder<'b> { pub fn uncommitted(&self) -> &[u8] { let message = self.contents.get().cast::().cast_const(); // SAFETY: It is guaranteed that 'start <= message.len()'. - let message = unsafe { message.offset(self.start as isize) }; + let message = unsafe { message.add(self.start) }; let size = self.context.size - self.start; // SAFETY: 'message[start..]' is mutably borrowed. unsafe { slice::from_raw_parts(message, size) } @@ -192,7 +192,7 @@ impl<'b> Builder<'b> { pub unsafe fn uncommitted_mut(&mut self) -> &mut [u8] { let message = self.contents.get().cast::(); // SAFETY: It is guaranteed that 'start <= message.len()'. - let message = unsafe { message.offset(self.start as isize) }; + let message = unsafe { message.add(self.start) }; let size = self.context.size - self.start; // SAFETY: 'message[start..]' is mutably borrowed. unsafe { slice::from_raw_parts_mut(message, size) } @@ -206,7 +206,7 @@ impl<'b> Builder<'b> { pub fn uninitialized(&mut self) -> &mut [u8] { let message = self.contents.get().cast::(); // SAFETY: It is guaranteed that 'size <= message.len()'. - let message = unsafe { message.offset(self.context.size as isize) }; + let message = unsafe { message.add(self.context.size) }; let size = self.max_size() - self.context.size; // SAFETY: 'message[size..]' is mutably borrowed. unsafe { slice::from_raw_parts_mut(message, size) } diff --git a/src/new_base/build/context.rs b/src/new_base/build/context.rs index 112330123..e62ad265b 100644 --- a/src/new_base/build/context.rs +++ b/src/new_base/build/context.rs @@ -137,13 +137,13 @@ impl MessageState { /// Whether a question or record is being built. pub const fn mid_component(&self) -> bool { - match self { - Self::MidQuestion { .. } => true, - Self::MidAnswer { .. } => true, - Self::MidAuthority { .. } => true, - Self::MidAdditional { .. } => true, - _ => false, - } + matches!( + self, + Self::MidQuestion { .. } + | Self::MidAnswer { .. } + | Self::MidAuthority { .. } + | Self::MidAdditional { .. } + ) } /// Commit a question or record and update the section counts. diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 3101ecd76..5e969115f 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -52,7 +52,7 @@ impl<'b> MessageBuilder<'b> { //--- Inspection -impl<'b> MessageBuilder<'b> { +impl MessageBuilder<'_> { /// The message header. pub fn header(&self) -> &Header { &self.message.header @@ -135,7 +135,7 @@ impl MessageBuilder<'_> { unsafe { let contents = &mut self.message.contents; let contents = contents as *mut [u8] as *const UnsafeCell<[u8]>; - Builder::from_raw_parts(&*contents, &mut self.context, start) + Builder::from_raw_parts(&*contents, self.context, start) } } diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs index 680b828fd..95fa095ae 100644 --- a/src/new_base/build/question.rs +++ b/src/new_base/build/question.rs @@ -91,7 +91,7 @@ impl<'b> QuestionBuilder<'b> { //--- Interaction -impl<'b> QuestionBuilder<'b> { +impl QuestionBuilder<'_> { /// Commit this question. /// /// The builder will be consumed, and the question will be committed so diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index 6a27826c3..f74418a13 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -53,9 +53,9 @@ impl<'b> RecordBuilder<'b> { let name = start.try_into().expect("Messages are at most 64KiB"); let mut b = builder.builder(start); record.rname.build_into_message(b.delegate())?; - b.append_bytes(&record.rtype.as_bytes())?; - b.append_bytes(&record.rclass.as_bytes())?; - b.append_bytes(&record.ttl.as_bytes())?; + b.append_bytes(record.rtype.as_bytes())?; + b.append_bytes(record.rclass.as_bytes())?; + b.append_bytes(record.ttl.as_bytes())?; let size = b.context().size; SizePrefixed::new(&record.rdata) .build_into_message(b.delegate())?; @@ -163,7 +163,7 @@ impl<'b> RecordBuilder<'b> { //--- Interaction -impl<'b> RecordBuilder<'b> { +impl RecordBuilder<'_> { /// Commit this record. /// /// The builder will be consumed, and the record will be committed so that From 1112224cd472ba69c8f46d0cea87ab2c9a4e2a82 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 13:20:25 +0100 Subject: [PATCH 104/167] [new_rdata] Use 'core::net::Ipv4Addr' --- src/new_rdata/basic.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 43d089100..2f4e24035 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -3,13 +3,9 @@ //! See [RFC 1035](https://datatracker.ietf.org/doc/html/rfc1035). use core::fmt; - -#[cfg(feature = "std")] +use core::net::Ipv4Addr; use core::str::FromStr; -#[cfg(feature = "std")] -use std::net::Ipv4Addr; - use domain_macros::*; use crate::new_base::{ @@ -46,7 +42,6 @@ pub struct A { //--- Converting to and from 'Ipv4Addr' -#[cfg(feature = "std")] impl From for A { fn from(value: Ipv4Addr) -> Self { Self { @@ -55,7 +50,6 @@ impl From for A { } } -#[cfg(feature = "std")] impl From for Ipv4Addr { fn from(value: A) -> Self { Self::from(value.octets) @@ -64,7 +58,6 @@ impl From for Ipv4Addr { //--- Parsing from a string -#[cfg(feature = "std")] impl FromStr for A { type Err = ::Err; @@ -77,8 +70,7 @@ impl FromStr for A { impl fmt::Display for A { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let [a, b, c, d] = self.octets; - write!(f, "{a}.{b}.{c}.{d}") + Ipv4Addr::from(*self).fmt(f) } } From c808946addf755900864e38e7fd02c938e745346 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 13:25:04 +0100 Subject: [PATCH 105/167] [new_rdata] Document format in 'Txt' See: --- src/new_rdata/basic.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs index 2f4e24035..e880721da 100644 --- a/src/new_rdata/basic.rs +++ b/src/new_rdata/basic.rs @@ -437,6 +437,8 @@ impl BuildIntoMessage for Mx { #[repr(transparent)] pub struct Txt { /// The text strings, as concatenated [`CharStr`]s. + /// + /// The [`CharStr`]s begin with a length octet so they can be separated. content: [u8], } From 2d3e9110fb541e5cd6ee4dbde67e426718cf9c2d Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 24 Jan 2025 18:17:30 +0100 Subject: [PATCH 106/167] [new_base/name] Add 'LabelBuf' and support building on 'Label' --- src/new_base/name/label.rs | 221 ++++++++++++++++++++++++++++++++++++- src/new_base/name/mod.rs | 2 +- 2 files changed, 220 insertions(+), 3 deletions(-) diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index 9cb4d1d85..3c5c44239 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -1,15 +1,21 @@ //! Labels in domain names. use core::{ + borrow::{Borrow, BorrowMut}, cmp::Ordering, fmt, hash::{Hash, Hasher}, iter::FusedIterator, + ops::{Deref, DerefMut}, }; use domain_macros::AsBytes; -use crate::new_base::wire::{ParseBytes, ParseError, SplitBytes}; +use crate::new_base::{ + build::{BuildIntoMessage, BuildResult, Builder}, + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, +}; //----------- Label ---------------------------------------------------------- @@ -48,9 +54,52 @@ impl Label { // SAFETY: 'Label' is 'repr(transparent)' to '[u8]'. unsafe { core::mem::transmute(bytes) } } + + /// Assume a mutable byte slice is a valid label. + /// + /// # Safety + /// + /// The byte slice must have length 63 or less. + pub unsafe fn from_bytes_unchecked_mut(bytes: &mut [u8]) -> &mut Self { + // SAFETY: 'Label' is 'repr(transparent)' to '[u8]'. + unsafe { core::mem::transmute(bytes) } + } } -//--- Parsing +//--- Parsing from DNS messages + +impl<'a> ParseMessageBytes<'a> for &'a Label { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + Self::parse_bytes(&contents[start..]) + } +} + +impl<'a> SplitMessageBytes<'a> for &'a Label { + fn split_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result<(Self, usize), ParseError> { + Self::split_bytes(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Label { + fn build_into_message(&self, mut builder: Builder<'_>) -> BuildResult { + builder.append_with(self.len() + 1, |buf| { + buf[0] = self.len() as u8; + buf[1..].copy_from_slice(self.as_bytes()); + })?; + Ok(builder.commit()) + } +} + +//--- Parsing from bytes impl<'a> SplitBytes<'a> for &'a Label { fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { @@ -73,6 +122,20 @@ impl<'a> ParseBytes<'a> for &'a Label { } } +//--- Building into byte strings + +impl BuildBytes for Label { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + let (size, data) = bytes.split_first_mut().ok_or(TruncationError)?; + let rest = self.as_bytes().build_bytes(data)?; + *size = self.len() as u8; + Ok(rest) + } +} + //--- Inspection impl Label { @@ -209,6 +272,160 @@ impl fmt::Debug for Label { } } +//----------- LabelBuf ------------------------------------------------------- + +/// A 64-byte buffer holding a [`Label`]. +#[derive(Clone)] +#[repr(C)] // make layout compatible with '[u8; 64]' +pub struct LabelBuf { + /// The size of the label, in bytes. + /// + /// This value is guaranteed to be in the range '0..64'. + size: u8, + + /// The underlying label data. + data: [u8; 63], +} + +//--- Construction + +impl LabelBuf { + /// Copy a [`Label`] into a buffer. + pub fn copy_from(label: &Label) -> Self { + let size = label.len() as u8; + let mut data = [0u8; 63]; + data[..size as usize].copy_from_slice(label.as_bytes()); + Self { size, data } + } +} + +//--- Parsing from DNS messages + +impl ParseMessageBytes<'_> for LabelBuf { + fn parse_message_bytes( + contents: &'_ [u8], + start: usize, + ) -> Result { + Self::parse_bytes(&contents[start..]) + } +} + +impl SplitMessageBytes<'_> for LabelBuf { + fn split_message_bytes( + contents: &'_ [u8], + start: usize, + ) -> Result<(Self, usize), ParseError> { + Self::split_bytes(&contents[start..]) + .map(|(this, rest)| (this, contents.len() - start - rest.len())) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for LabelBuf { + fn build_into_message(&self, builder: Builder<'_>) -> BuildResult { + (**self).build_into_message(builder) + } +} + +//--- Parsing from byte strings + +impl ParseBytes<'_> for LabelBuf { + fn parse_bytes(bytes: &[u8]) -> Result { + <&Label>::parse_bytes(bytes).map(Self::copy_from) + } +} + +impl SplitBytes<'_> for LabelBuf { + fn split_bytes(bytes: &'_ [u8]) -> Result<(Self, &'_ [u8]), ParseError> { + <&Label>::split_bytes(bytes) + .map(|(label, rest)| (Self::copy_from(label), rest)) + } +} + +//--- Building into byte strings + +impl BuildBytes for LabelBuf { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + (**self).build_bytes(bytes) + } +} + +//--- Access to the underlying 'Label' + +impl Deref for LabelBuf { + type Target = Label; + + fn deref(&self) -> &Self::Target { + let label = &self.data[..self.size as usize]; + // SAFETY: A 'LabelBuf' always contains a valid 'Label'. + unsafe { Label::from_bytes_unchecked(label) } + } +} + +impl DerefMut for LabelBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + let label = &mut self.data[..self.size as usize]; + // SAFETY: A 'LabelBuf' always contains a valid 'Label'. + unsafe { Label::from_bytes_unchecked_mut(label) } + } +} + +impl Borrow::respond(&self.0, request).await { ControlFlow::Continue(at) => at, ControlFlow::Break(ap) => return ControlFlow::Break(Left(ap)), }; - match self.1.respond(request.1) { + match ::respond(&self.1, request).await { ControlFlow::Continue(bt) => ControlFlow::Continue((at, bt)), ControlFlow::Break(bp) => ControlFlow::Break(Right((at, bp))), } @@ -289,40 +268,21 @@ where macro_rules! impl_service_layer_tuple { ($first:ident .. $last:ident: $($middle:ident)+) => { - impl<'req, $first, $($middle,)+ $last: ?Sized> - ServiceLayer<'req> for ($first, $($middle,)+ $last) + impl<$first, $($middle,)+ $last: ?Sized> + ServiceLayer for ($first, $($middle,)+ $last) where - $first: ServiceLayer<'req>, - $($middle: ServiceLayer<'req>,)+ - $last: ServiceLayer<'req>, + $first: ServiceLayer, + $($middle: ServiceLayer,)+ + $last: ServiceLayer, { - type Consumer = - <($first, ($($middle,)+ $last)) as ServiceLayer<'req>> - ::Consumer; - - type Producer = - <($first, ($($middle,)+ $last)) as ServiceLayer<'req>> - ::Producer; - - type Transformer = - <($first, ($($middle,)+ $last)) as ServiceLayer<'req>> - ::Transformer; - - fn consume(&self) -> Self::Consumer - { - #[allow(non_snake_case)] - let ($first, $($middle,)+ $last) = self; - ($first, ($($middle,)+ $last)).consume() - } - - fn respond( + async fn respond( &self, - request: Self::Consumer, + request: &RequestMessage<'_>, ) -> ControlFlow { #[allow(non_snake_case)] let ($first, $($middle,)+ ref $last) = self; - ($first, ($($middle,)+ $last)).respond(request) + ($first, ($($middle,)+ $last)).respond(request).await } } } @@ -340,24 +300,14 @@ impl_service_layer_tuple!(A..K: B C D E F G H I J); impl_service_layer_tuple!(A..L: B C D E F G H I J K); #[cfg(feature = "std")] -impl<'req, T: ServiceLayer<'req>> ServiceLayer<'req> for [T] { - type Consumer = Box<[T::Consumer]>; - type Producer = (Box<[T::Transformer]>, T::Producer); - type Transformer = Box<[T::Transformer]>; - - fn consume(&self) -> Self::Consumer { - self.iter().map(T::consume).collect() - } - - fn respond( +impl ServiceLayer for [T] { + async fn respond( &self, - request: Self::Consumer, + request: &RequestMessage<'_>, ) -> ControlFlow { let mut transformers = Vec::new(); - // TODO (MSRV 1.80): Use Box<[T]>: IntoIterator - let request: Vec<_> = request.into(); - for (layer, request) in self.iter().zip(request) { - match layer.respond(request) { + for layer in self { + match layer.respond(request).await { ControlFlow::Continue(t) => transformers.push(t), ControlFlow::Break(p) => { return ControlFlow::Break((transformers.into(), p)); @@ -369,256 +319,183 @@ impl<'req, T: ServiceLayer<'req>> ServiceLayer<'req> for [T] { } #[cfg(feature = "std")] -impl<'req, T: ServiceLayer<'req>> ServiceLayer<'req> for Vec { - type Consumer = Box<[T::Consumer]>; - type Producer = (Box<[T::Transformer]>, T::Producer); - type Transformer = Box<[T::Transformer]>; - - fn consume(&self) -> Self::Consumer { - self.as_slice().consume() +impl ServiceLayer for Vec { + async fn respond( + &self, + request: &RequestMessage<'_>, + ) -> ControlFlow { + self.as_slice().respond(request).await } +} + +//----------- impl LocalServiceLayer ----------------------------------------- + +impl LocalServiceLayer for &T { + type Producer = T::Producer; - fn respond( + type Transformer = T::Transformer; + + async fn respond_local( &self, - request: Self::Consumer, + request: &RequestMessage<'_>, ) -> ControlFlow { - self.as_slice().respond(request) + T::respond_local(self, request).await } } -//----------- impl ConsumeMessage -------------------------------------------- +impl LocalServiceLayer for &mut T { + type Producer = T::Producer; -impl<'msg, T: ?Sized + ConsumeMessage<'msg>> ConsumeMessage<'msg> for &mut T { - fn consume_header(&mut self, header: &'msg Header) { - T::consume_header(self, header); - } + type Transformer = T::Transformer; - fn consume_question(&mut self, question: &Question) { - T::consume_question(self, question); + async fn respond_local( + &self, + request: &RequestMessage<'_>, + ) -> ControlFlow { + T::respond_local(self, request).await } +} - fn consume_answer( - &mut self, - answer: &Record>, - ) { - T::consume_answer(self, answer); - } +#[cfg(feature = "std")] +impl LocalServiceLayer for Box { + type Producer = T::Producer; - fn consume_authority( - &mut self, - authority: &Record>, - ) { - T::consume_authority(self, authority); - } + type Transformer = T::Transformer; - fn consume_additional( - &mut self, - additional: &Record>, - ) { - T::consume_additional(self, additional); + async fn respond_local( + &self, + request: &RequestMessage<'_>, + ) -> ControlFlow { + T::respond_local(self, request).await } } #[cfg(feature = "std")] -impl<'msg, T: ?Sized + ConsumeMessage<'msg>> ConsumeMessage<'msg> for Box { - fn consume_header(&mut self, header: &'msg Header) { - T::consume_header(self, header); - } +impl LocalServiceLayer for Rc { + type Producer = T::Producer; - fn consume_question(&mut self, question: &Question) { - T::consume_question(self, question); - } + type Transformer = T::Transformer; - fn consume_answer( - &mut self, - answer: &Record>, - ) { - T::consume_answer(self, answer); + async fn respond_local( + &self, + request: &RequestMessage<'_>, + ) -> ControlFlow { + T::respond_local(self, request).await } +} - fn consume_authority( - &mut self, - authority: &Record>, - ) { - T::consume_authority(self, authority); - } +#[cfg(feature = "std")] +impl LocalServiceLayer for Arc { + type Producer = T::Producer; - fn consume_additional( - &mut self, - additional: &Record>, - ) { - T::consume_additional(self, additional); + type Transformer = T::Transformer; + + async fn respond_local( + &self, + request: &RequestMessage<'_>, + ) -> ControlFlow { + T::respond_local(self, request).await } } -impl<'msg, A, B> ConsumeMessage<'msg> for Either +impl LocalServiceLayer for (A, B) where - A: ConsumeMessage<'msg>, - B: ConsumeMessage<'msg>, + A: LocalServiceLayer, + B: LocalServiceLayer, { - fn consume_header(&mut self, header: &'msg Header) { - either::for_both!(self, x => x.consume_header(header)) - } - - fn consume_question(&mut self, question: &Question) { - either::for_both!(self, x => x.consume_question(question)) - } + type Producer = Either; - fn consume_answer( - &mut self, - answer: &Record>, - ) { - either::for_both!(self, x => x.consume_answer(answer)) - } + type Transformer = (A::Transformer, B::Transformer); - fn consume_authority( - &mut self, - authority: &Record>, - ) { - either::for_both!(self, x => x.consume_authority(authority)) - } + async fn respond_local( + &self, + request: &RequestMessage<'_>, + ) -> ControlFlow { + let at = match self.0.respond_local(request).await { + ControlFlow::Continue(at) => at, + ControlFlow::Break(ap) => return ControlFlow::Break(Left(ap)), + }; - fn consume_additional( - &mut self, - additional: &Record>, - ) { - either::for_both!(self, x => x.consume_additional(additional)) + match self.1.respond_local(request).await { + ControlFlow::Continue(bt) => ControlFlow::Continue((at, bt)), + ControlFlow::Break(bp) => ControlFlow::Break(Right((at, bp))), + } } } -macro_rules! impl_consume_message_tuple { - ($($middle:ident)* .. $last:ident) => { - impl<'msg, $($middle,)* $last: ?Sized> - ConsumeMessage<'msg> for ($($middle,)* $last,) +macro_rules! impl_local_service_layer_tuple { + ($first:ident .. $last:ident: $($middle:ident)+) => { + impl<$first, $($middle,)+ $last: ?Sized> + LocalServiceLayer for ($first, $($middle,)+ $last) where - $($middle: ConsumeMessage<'msg>,)* - $last: ConsumeMessage<'msg>, + $first: LocalServiceLayer, + $($middle: LocalServiceLayer,)+ + $last: LocalServiceLayer, { - fn consume_header(&mut self, header: &'msg Header) { - #[allow(non_snake_case)] - let ($($middle,)* ref mut $last,) = self; - $($middle.consume_header(header);)* - $last.consume_header(header); - } - - fn consume_question(&mut self, question: &Question) { - #[allow(non_snake_case)] - let ($($middle,)* ref mut $last,) = self; - $($middle.consume_question(question);)* - $last.consume_question(question); - } - - fn consume_answer( - &mut self, - answer: &Record>, - ) { - #[allow(non_snake_case)] - let ($($middle,)* ref mut $last,) = self; - $($middle.consume_answer(answer);)* - $last.consume_answer(answer); - } + type Producer = + <($first, ($($middle,)+ $last)) as LocalServiceLayer> + ::Producer; - fn consume_authority( - &mut self, - authority: &Record>, - ) { - #[allow(non_snake_case)] - let ($($middle,)* ref mut $last,) = self; - $($middle.consume_authority(authority);)* - $last.consume_authority(authority); - } + type Transformer = + <($first, ($($middle,)+ $last)) as LocalServiceLayer> + ::Transformer; - fn consume_additional( - &mut self, - additional: &Record>, - ) { + async fn respond_local( + &self, + request: &RequestMessage<'_>, + ) -> ControlFlow + { #[allow(non_snake_case)] - let ($($middle,)* ref mut $last,) = self; - $($middle.consume_additional(additional);)* - $last.consume_additional(additional); + let ($first, $($middle,)+ ref $last) = self; + ($first, ($($middle,)+ $last)).respond_local(request).await } } - }; -} - -impl_consume_message_tuple!(..A); -impl_consume_message_tuple!(A..B); -impl_consume_message_tuple!(A B..C); -impl_consume_message_tuple!(A B C..D); -impl_consume_message_tuple!(A B C D..E); -impl_consume_message_tuple!(A B C D E..F); -impl_consume_message_tuple!(A B C D E F..G); -impl_consume_message_tuple!(A B C D E F G..H); -impl_consume_message_tuple!(A B C D E F G H..I); -impl_consume_message_tuple!(A B C D E F G H I..J); -impl_consume_message_tuple!(A B C D E F G H I J..K); -impl_consume_message_tuple!(A B C D E F G H I J K..L); - -impl<'msg, T: ConsumeMessage<'msg>> ConsumeMessage<'msg> for [T] { - fn consume_header(&mut self, header: &'msg Header) { - self.iter_mut() - .for_each(|layer| layer.consume_header(header)); - } - - fn consume_question(&mut self, question: &Question) { - self.iter_mut() - .for_each(|layer| layer.consume_question(question)); } +} - fn consume_answer( - &mut self, - answer: &Record>, - ) { - self.iter_mut() - .for_each(|layer| layer.consume_answer(answer)); - } +impl_local_service_layer_tuple!(A..C: B); +impl_local_service_layer_tuple!(A..D: B C); +impl_local_service_layer_tuple!(A..E: B C D); +impl_local_service_layer_tuple!(A..F: B C D E); +impl_local_service_layer_tuple!(A..G: B C D E F); +impl_local_service_layer_tuple!(A..H: B C D E F G); +impl_local_service_layer_tuple!(A..I: B C D E F G H); +impl_local_service_layer_tuple!(A..J: B C D E F G H I); +impl_local_service_layer_tuple!(A..K: B C D E F G H I J); +impl_local_service_layer_tuple!(A..L: B C D E F G H I J K); - fn consume_authority( - &mut self, - authority: &Record>, - ) { - self.iter_mut() - .for_each(|layer| layer.consume_authority(authority)); - } +#[cfg(feature = "std")] +impl LocalServiceLayer for [T] { + type Producer = (Box<[T::Transformer]>, T::Producer); + type Transformer = Box<[T::Transformer]>; - fn consume_additional( - &mut self, - additional: &Record>, - ) { - self.iter_mut() - .for_each(|layer| layer.consume_additional(additional)); + async fn respond_local( + &self, + request: &RequestMessage<'_>, + ) -> ControlFlow { + let mut transformers = Vec::new(); + for layer in self { + match layer.respond_local(request).await { + ControlFlow::Continue(t) => transformers.push(t), + ControlFlow::Break(p) => { + return ControlFlow::Break((transformers.into(), p)); + } + } + } + ControlFlow::Continue(transformers.into()) } } #[cfg(feature = "std")] -impl<'msg, T: ConsumeMessage<'msg>> ConsumeMessage<'msg> for Vec { - fn consume_header(&mut self, header: &'msg Header) { - self.as_mut_slice().consume_header(header) - } - - fn consume_question(&mut self, question: &Question) { - self.as_mut_slice().consume_question(question) - } - - fn consume_answer( - &mut self, - answer: &Record>, - ) { - self.as_mut_slice().consume_answer(answer) - } - - fn consume_authority( - &mut self, - authority: &Record>, - ) { - self.as_mut_slice().consume_authority(authority) - } +impl LocalServiceLayer for Vec { + type Producer = (Box<[T::Transformer]>, T::Producer); + type Transformer = Box<[T::Transformer]>; - fn consume_additional( - &mut self, - additional: &Record>, - ) { - self.as_mut_slice().consume_additional(additional) + async fn respond_local( + &self, + request: &RequestMessage<'_>, + ) -> ControlFlow { + self.as_slice().respond_local(request).await } } diff --git a/src/new_net/server/mod.rs b/src/new_net/server/mod.rs index 87b6dd8bc..cf8b7d358 100644 --- a/src/new_net/server/mod.rs +++ b/src/new_net/server/mod.rs @@ -14,43 +14,64 @@ #![cfg(feature = "unstable-server-transport")] #![cfg_attr(docsrs, doc(cfg(feature = "unstable-server-transport")))] -use core::ops::ControlFlow; +use core::{future::Future, ops::ControlFlow}; -use crate::{ - new_base::{ - build::{MessageBuilder, QuestionBuilder, RecordBuilder}, - name::RevNameBuf, - Header, Question, Record, - }, - new_rdata::RecordData, +use crate::new_base::{ + build::{MessageBuilder, QuestionBuilder, RecordBuilder}, + Header, }; mod impls; +mod request; +pub use request::RequestMessage; + //----------- Service -------------------------------------------------------- -/// A DNS service, that computes responses for requests. +/// A (multi-threaded) DNS service, that computes responses for requests. /// /// Given a DNS request message, a service computes an appropriate response. /// Services are usually wrapped in a network transport that receives requests /// and returns the service's responses. /// +/// Use [`LocalService`] for a single-threaded equivalent. +/// /// # Layering /// /// Additional functionality can be added to a service by prefixing it with /// service layers, usually in a tuple. A number of blanket implementations /// are provided to simplify this. -pub trait Service<'req> { - /// A consumer of DNS requests. - /// - /// This type is given access to every component in a DNS request message. - /// It should store only the information relevant to this service. +pub trait Service: LocalService + Sync { + /// Respond to a DNS request. /// - /// # Lifetimes + /// The provided consumer must have been provided the entire DNS request + /// message. This method will use the extracted information to formulate + /// a response message, in the form of a producer type. /// - /// The consumer can borrow from the request message (`'req`). - type Consumer: ConsumeMessage<'req>; + /// The returned future implements [`Send`]. Use [`LocalService`] and + /// [`LocalService::respond_local()`] if [`Send`] is not necessary. + fn respond( + &self, + request: &RequestMessage<'_>, + ) -> impl Future + Send; +} +//----------- LocalService --------------------------------------------------- + +/// A (single-threaded) DNS service, that computes responses for requests. +/// +/// Given a DNS request message, a service computes an appropriate response. +/// Services are usually wrapped in a network transport that receives requests +/// and returns the service's responses. +/// +/// Use [`Service`] for a multi-threaded equivalent. +/// +/// # Layering +/// +/// Additional functionality can be added to a service by prefixing it with +/// service layers, usually in a tuple. A number of blanket implementations +/// are provided to simplify this. +pub trait LocalService { /// A producer of DNS responses. /// /// This type returns components to insert in a DNS response message. It @@ -62,43 +83,69 @@ pub trait Service<'req> { /// it cannot borrow from the response message. type Producer: ProduceMessage; - /// Consume a DNS request. - /// - /// The returned consumer should be provided with a DNS request message. - /// After the whole message is consumed, a response message can be built - /// using [`Self::respond()`]. - fn consume(&self) -> Self::Consumer; - /// Respond to a DNS request. /// /// The provided consumer must have been provided the entire DNS request /// message. This method will use the extracted information to formulate /// a response message, in the form of a producer type. - fn respond(&self, request: Self::Consumer) -> Self::Producer; + /// + /// The returned future does not implement [`Send`]. Use [`Service`] and + /// [`Service::respond()`] for a [`Send`]-implementing version. + #[allow(async_fn_in_trait)] + async fn respond_local( + &self, + request: &RequestMessage<'_>, + ) -> Self::Producer; } //----------- ServiceLayer --------------------------------------------------- -/// A layer wrapping a DNS [`Service`]. +/// A (multi-threaded) layer wrapping a DNS [`Service`]. /// /// A layer can be wrapped around a service, inspecting the requests sent to /// it and transforming the responses returned by it. /// +/// Use [`LocalServiceLayer`] for a single-threaded equivalent. +/// /// # Combinations /// /// Layers can be combined (usually in a tuple) into larger layers. A number /// of blanket implementations are provided to simplify this. -pub trait ServiceLayer<'req> { - /// A consumer of DNS requests. - /// - /// This type is given access to every component in a DNS request message. - /// It should store only the information relevant to this service. +pub trait ServiceLayer: + LocalServiceLayer + Sync +{ + /// Respond to a DNS request. /// - /// # Lifetimes + /// The provided consumer must have been provided the entire DNS request + /// message. If the request should be forwarded through to the wrapped + /// service, [`ControlFlow::Continue`] is returned, with a transformer to + /// modify the wrapped service's response. However, if the request should + /// be responded to directly by this layer, without any interaction from + /// the wrapped service, [`ControlFlow::Break`] is returned. /// - /// The consumer can borrow from the request message (`'req`). - type Consumer: ConsumeMessage<'req>; + /// The returned future implements [`Send`]. Use [`LocalServiceLayer`] + /// and [`LocalServiceLayer::respond_local()`] if [`Send`] is not + /// necessary. + fn respond( + &self, + request: &RequestMessage<'_>, + ) -> impl Future> + Send; +} + +//----------- LocalServiceLayer ---------------------------------------------- +/// A (single-threaded) layer wrapping a DNS [`Service`]. +/// +/// A layer can be wrapped around a service, inspecting the requests sent to +/// it and transforming the responses returned by it. +/// +/// Use [`ServiceLayer`] for a multi-threaded equivalent. +/// +/// # Combinations +/// +/// Layers can be combined (usually in a tuple) into larger layers. A number +/// of blanket implementations are provided to simplify this. +pub trait LocalServiceLayer { /// A producer of DNS responses. /// /// This type returns components to insert in a DNS response message. It @@ -123,14 +170,6 @@ pub trait ServiceLayer<'req> { /// that it cannot borrow from the response message. type Transformer: TransformMessage; - /// Consume a DNS request. - /// - /// The returned consumer should be provided with a DNS request message, - /// that should also be provided to the wrapped service. After the whole - /// message is consumed, the wrapped service's response can be transformed - /// using [`Self::respond()`]. - fn consume(&self) -> Self::Consumer; - /// Respond to a DNS request. /// /// The provided consumer must have been provided the entire DNS request @@ -139,108 +178,17 @@ pub trait ServiceLayer<'req> { /// modify the wrapped service's response. However, if the request should /// be responded to directly by this layer, without any interaction from /// the wrapped service, [`ControlFlow::Break`] is returned. - fn respond( + /// + /// The returned future does not implement [`Send`]. Use [`ServiceLayer`] + /// and [`ServiceLayer::respond_local()`] for a [`Send`]-implementing + /// version. + #[allow(async_fn_in_trait)] + async fn respond_local( &self, - request: Self::Consumer, + request: &RequestMessage<'_>, ) -> ControlFlow; } -//----------- ConsumeMessage ------------------------------------------------- - -/// A type that consumes a DNS message. -/// -/// This interface is akin to a visitor pattern; its methods get called on -/// every component of the DNS message. Implementing types are expected to -/// extract whatever information they need and ignore most of the input. -/// -/// The consumer's methods should only be called in a fixed order, as they are -/// laid out in the wire format of a DNS message: the header, then questions, -/// answers, authorities, and additional records. -/// -/// # Lifetimes -/// -/// Implementing types can borrow from the message they are consuming; the -/// message is offered for the lifetime `'msg`. However, decompressed names -/// are offered for a temporary lifetime, as they would otherwise have to be -/// allocated on the heap. -/// -/// # Architecture -/// -/// This interface is convenient when multiple independent consumers need to -/// consume the same message. Rather than forcing each consumer to iterate -/// over the entire message every time, including resolving compressed names, -/// this interface allows a message to be iterated through once but examined -/// by an arbitrary number of consumers. For this reason, a number of blanket -/// implementations (e.g. for tuples and slices) are provided. -/// -/// # Examples -/// -/// ``` -/// # use domain::new_base::{Question, name::RevNameBuf}; -/// # use domain::new_net::server::ConsumeMessage; -/// -/// /// A type that extracts the first question in a DNS message. -/// struct FirstQuestion(Option>); -/// -/// impl ConsumeMessage<'_> for FirstQuestion { -/// fn consume_question(&mut self, question: &Question) { -/// if self.0.is_none() { -/// self.0 = Some(question.clone()); -/// } -/// } -/// } -/// ``` -pub trait ConsumeMessage<'msg> { - /// Consume the header of the message. - fn consume_header(&mut self, header: &'msg Header) { - let _ = header; - } - - /// Consume a DNS question. - /// - /// The question is offered for a temporary lifetime because it contains a - /// decompressed name, which is stored outside the original message. - fn consume_question(&mut self, question: &Question) { - let _ = question; - } - - /// Consume a DNS answer record. - /// - /// The record is offered for a temporary lifetime because it contains a - /// decompressed name, which is stored outside the original message. - /// However, record data may be referenced from the original message. - fn consume_answer( - &mut self, - answer: &Record>, - ) { - let _ = answer; - } - - /// Consume a DNS authority record. - /// - /// The record is offered for a temporary lifetime because it contains a - /// decompressed name, which is stored outside the original message. - /// However, record data may be referenced from the original message. - fn consume_authority( - &mut self, - authority: &Record>, - ) { - let _ = authority; - } - - /// Consume a DNS additional record. - /// - /// The record is offered for a temporary lifetime because it contains a - /// decompressed name, which is stored outside the original message. - /// However, record data may be referenced from the original message. - fn consume_additional( - &mut self, - additional: &Record>, - ) { - let _ = additional; - } -} - //----------- ProduceMessage ------------------------------------------------- /// A type that produces a DNS message. diff --git a/src/new_net/server/request.rs b/src/new_net/server/request.rs new file mode 100644 index 000000000..1d9b59ef2 --- /dev/null +++ b/src/new_net/server/request.rs @@ -0,0 +1,18 @@ +//! DNS request messages. + +use crate::new_base::Message; + +/// A DNS request message. +pub struct RequestMessage<'b> { + /// The underlying [`Message`]. + pub message: &'b Message, + + /// Cached indices of the initial questions and records. + indices: [(u16, u16); 8], + + /// Cached indices of the EDNS options in the message. + edns_indices: [(u16, u16); 8], + + /// The number of components before the end of every section. + section_offsets: [u16; 4], +} From 73e135af27b488d699ad40d7659b88033c0820b5 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 28 Jan 2025 15:59:54 +0100 Subject: [PATCH 117/167] [new_net/server/request] Implement construction --- src/new_net/server/request.rs | 134 +++++++++++++++++++++++++++++++++- 1 file changed, 133 insertions(+), 1 deletion(-) diff --git a/src/new_net/server/request.rs b/src/new_net/server/request.rs index 1d9b59ef2..e266be08c 100644 --- a/src/new_net/server/request.rs +++ b/src/new_net/server/request.rs @@ -1,6 +1,11 @@ //! DNS request messages. -use crate::new_base::Message; +use crate::new_base::{ + name::UnparsedName, + parse::SplitMessageBytes, + wire::{AsBytes, ParseError, SizePrefixed, U16}, + Message, Question, RType, Record, UnparsedRecordData, +}; /// A DNS request message. pub struct RequestMessage<'b> { @@ -8,11 +13,138 @@ pub struct RequestMessage<'b> { pub message: &'b Message, /// Cached indices of the initial questions and records. + /// + /// For questions, the indices span the whole question. + /// + /// For records, the indices span the record header (the record name, + /// type, class, TTL, and data size). The record data can be located + /// using the header very easily. indices: [(u16, u16); 8], + /// Cached offset of the EDNS record. + edns_offset: u16, + /// Cached indices of the EDNS options in the message. edns_indices: [(u16, u16); 8], /// The number of components before the end of every section. section_offsets: [u16; 4], } + +//--- Construction + +impl<'b> RequestMessage<'b> { + /// Wrap a raw [`Message`] into a [`RequestMessage`]. + /// + /// This will iterate through the message, pre-filling some caches for + /// efficient access in the future. + pub fn new(message: &'b Message) -> Result { + let mut indices = [(0u16, 0u16); 8]; + let mut edns_indices = [(0u16, 0u16); 8]; + + // DNS messages are 64KiB at the largest. + let _ = u16::try_from(message.as_bytes().len()) + .map_err(|_| ParseError)?; + + // The offset (in bytes) into the message contents. + let mut offset = 0; + + // The section counts from the message. + let counts = &message.header.counts.as_array(); + + // The offset of each section, in components. + let mut section_offsets = [0u16; 4]; + + // First, parse all questions. + for i in 0..counts[0].get() { + let (_question, rest) = + Question::<&'b UnparsedName>::split_message_bytes( + &message.contents, + offset, + )?; + + if let Some(indices) = indices.get_mut(i as usize) { + *indices = (offset as u16, rest as u16); + } + + offset = rest; + } + + // The offset (in components) of this section in the message. + let mut section_offset = counts[0].get(); + section_offsets[0] = section_offset; + + // The offset of the EDNS record, if any. + let mut edns_offset = u16::MAX; + + // Parse all records. + for section in 1..4 { + for i in 0..counts[section].get() { + let (record, rest) = Record::< + &'b UnparsedName, + &'b UnparsedRecordData, + >::split_message_bytes( + &message.contents, offset + )?; + + let component = (section_offset + i) as usize; + if let Some(indices) = indices.get_mut(component) { + let data = offset + record.rname.len() + 10; + *indices = (offset as u16, data as u16); + } + + if record.rtype == RType::OPT { + if edns_offset != u16::MAX { + // A DNS message can only contain one EDNS record. + return Err(ParseError); + } else { + edns_offset = offset as u16; + } + } + + offset = rest; + } + + section_offset += counts[section].get(); + section_offsets[section] = section_offset; + } + + // Parse EDNS options. + if edns_offset < u16::MAX { + // Extract the EDNS record data. + let offset = edns_offset as usize + 9; + let (&size, mut offset) = + <&U16>::split_message_bytes(&message.contents, offset)?; + + let contents = message + .contents + .get(..offset + size.get() as usize) + .ok_or(ParseError)?; + + // Parse through it. + let mut indices = edns_indices.iter_mut(); + while offset < contents.len() { + let (_type, rest) = + <&U16>::split_message_bytes(contents, offset)?; + let (_data, rest) = + >::split_message_bytes( + contents, rest, + )?; + + if let Some(indices) = indices.next() { + *indices = (offset as u16, rest as u16); + } + + offset = rest; + } + } + + Ok(Self { + message, + indices, + edns_offset, + edns_indices, + section_offsets, + }) + } +} From 15e6b10eb500fcd5a1220e0780509ab82379f4c1 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 29 Jan 2025 11:44:25 +0100 Subject: [PATCH 118/167] [new_net/server/request] Rewrite init a bit more cleanly --- src/new_net/server/request.rs | 225 +++++++++++++++++++++------------- 1 file changed, 140 insertions(+), 85 deletions(-) diff --git a/src/new_net/server/request.rs b/src/new_net/server/request.rs index e266be08c..86808a0f3 100644 --- a/src/new_net/server/request.rs +++ b/src/new_net/server/request.rs @@ -1,34 +1,34 @@ //! DNS request messages. +use core::ops::Range; + use crate::new_base::{ name::UnparsedName, - parse::SplitMessageBytes, + parse::{ParseMessageBytes, SplitMessageBytes}, wire::{AsBytes, ParseError, SizePrefixed, U16}, Message, Question, RType, Record, UnparsedRecordData, }; /// A DNS request message. +#[derive(Clone)] pub struct RequestMessage<'b> { /// The underlying [`Message`]. pub message: &'b Message, - /// Cached indices of the initial questions and records. - /// - /// For questions, the indices span the whole question. - /// - /// For records, the indices span the record header (the record name, - /// type, class, TTL, and data size). The record data can be located - /// using the header very easily. - indices: [(u16, u16); 8], + /// Cached offsets for the question section. + questions: (Range, [Range; 1]), + + /// Cached offsets for the answer section. + answers: (Range, [Range; 0]), - /// Cached offset of the EDNS record. - edns_offset: u16, + /// Cached offsets for the authority section. + authorities: (Range, [Range; 0]), - /// Cached indices of the EDNS options in the message. - edns_indices: [(u16, u16); 8], + /// Cached offsets for the additional section. + additional: (Range, [Range; 2]), - /// The number of components before the end of every section. - section_offsets: [u16; 4], + /// Cached offsets for the EDNS record. + edns: (Range, u16, [Range; 4]), } //--- Construction @@ -39,91 +39,86 @@ impl<'b> RequestMessage<'b> { /// This will iterate through the message, pre-filling some caches for /// efficient access in the future. pub fn new(message: &'b Message) -> Result { - let mut indices = [(0u16, 0u16); 8]; - let mut edns_indices = [(0u16, 0u16); 8]; - - // DNS messages are 64KiB at the largest. - let _ = u16::try_from(message.as_bytes().len()) - .map_err(|_| ParseError)?; - - // The offset (in bytes) into the message contents. - let mut offset = 0; - - // The section counts from the message. - let counts = &message.header.counts.as_array(); - - // The offset of each section, in components. - let mut section_offsets = [0u16; 4]; + /// Parse the question section into cached offsets. + fn parse_questions( + contents: &[u8], + range: &mut Range, + number: u16, + indices: &mut [Range], + ) -> Result<(), ParseError> { + let mut indices = indices.iter_mut(); + let mut offset = range.start as usize; + + for _ in 0..number { + let (_question, rest) = + Question::<&UnparsedName>::split_message_bytes( + contents, offset, + )?; - // First, parse all questions. - for i in 0..counts[0].get() { - let (_question, rest) = - Question::<&'b UnparsedName>::split_message_bytes( - &message.contents, - offset, - )?; + if let Some(indices) = indices.next() { + *indices = offset as u16..rest as u16; + } - if let Some(indices) = indices.get_mut(i as usize) { - *indices = (offset as u16, rest as u16); + offset = rest; } - offset = rest; + range.end = offset as u16; + Ok(()) } - // The offset (in components) of this section in the message. - let mut section_offset = counts[0].get(); - section_offsets[0] = section_offset; - - // The offset of the EDNS record, if any. - let mut edns_offset = u16::MAX; - - // Parse all records. - for section in 1..4 { - for i in 0..counts[section].get() { + /// Parse a record section into cached offsets. + fn parse_records( + contents: &[u8], + section: u8, + range: &mut Range, + number: u16, + indices: &mut [Range], + edns_range: &mut Option>, + ) -> Result<(), ParseError> { + let mut indices = indices.iter_mut(); + let mut offset = range.start as usize; + + for _ in 0..number { let (record, rest) = Record::< - &'b UnparsedName, - &'b UnparsedRecordData, + &UnparsedName, + &UnparsedRecordData, >::split_message_bytes( - &message.contents, offset + contents, offset )?; + let data = offset + record.rname.len() + 10; + let range = offset as u16..data as u16; - let component = (section_offset + i) as usize; - if let Some(indices) = indices.get_mut(component) { - let data = offset + record.rname.len() + 10; - *indices = (offset as u16, data as u16); + if let Some(indices) = indices.next() { + *indices = range.clone(); } - if record.rtype == RType::OPT { - if edns_offset != u16::MAX { + if section == 3 && record.rtype == RType::OPT { + if edns_range.is_some() { // A DNS message can only contain one EDNS record. return Err(ParseError); - } else { - edns_offset = offset as u16; } + + *edns_range = Some(range); } offset = rest; } - section_offset += counts[section].get(); - section_offsets[section] = section_offset; + range.end = offset as u16; + Ok(()) } - // Parse EDNS options. - if edns_offset < u16::MAX { - // Extract the EDNS record data. - let offset = edns_offset as usize + 9; - let (&size, mut offset) = - <&U16>::split_message_bytes(&message.contents, offset)?; - - let contents = message - .contents - .get(..offset + size.get() as usize) - .ok_or(ParseError)?; - - // Parse through it. - let mut indices = edns_indices.iter_mut(); - while offset < contents.len() { + /// Parse the EDNS record into cached offsets. + fn parse_edns( + contents: &[u8], + range: Range, + number: &mut u16, + indices: &mut [Range], + ) -> Result<(), ParseError> { + let mut indices = indices.iter_mut(); + let mut offset = range.start as usize; + + while offset < range.end as usize { let (_type, rest) = <&U16>::split_message_bytes(contents, offset)?; let (_data, rest) = @@ -131,20 +126,80 @@ impl<'b> RequestMessage<'b> { contents, rest, )?; + *number += 1; + if let Some(indices) = indices.next() { - *indices = (offset as u16, rest as u16); + *indices = offset as u16..rest as u16; } offset = rest; } + + Ok(()) } - Ok(Self { + // DNS messages are 64KiB at the largest. + let _ = u16::try_from(message.as_bytes().len()) + .map_err(|_| ParseError)?; + + let mut this = Self { message, - indices, - edns_offset, - edns_indices, - section_offsets, - }) + questions: Default::default(), + answers: Default::default(), + authorities: Default::default(), + additional: Default::default(), + edns: Default::default(), + }; + + let mut edns_range = None; + + parse_questions( + &message.contents, + &mut this.questions.0, + message.header.counts.questions.get(), + &mut this.questions.1, + )?; + + this.answers.0 = this.questions.0.end..0; + parse_records( + &message.contents, + 1, + &mut this.answers.0, + message.header.counts.answers.get(), + &mut this.answers.1, + &mut edns_range, + )?; + + this.authorities.0 = this.answers.0.end..0; + parse_records( + &message.contents, + 2, + &mut this.authorities.0, + message.header.counts.authorities.get(), + &mut this.authorities.1, + &mut edns_range, + )?; + + this.additional.0 = this.authorities.0.end..0; + parse_records( + &message.contents, + 2, + &mut this.additional.0, + message.header.counts.additional.get(), + &mut this.additional.1, + &mut edns_range, + )?; + + if let Some(edns_range) = edns_range { + this.edns.0 = edns_range.clone(); + parse_edns( + &message.contents, + edns_range, + &mut this.edns.1, + &mut this.edns.2, + )?; + } + + Ok(this) } } From e0745c8bf2b1e3ec6bc63291bce8daab9a22d927 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 29 Jan 2025 12:04:17 +0100 Subject: [PATCH 119/167] [new_net/server] Impl 'RequestMessage::sole_question()' --- src/new_base/question.rs | 1 + src/new_net/server/request.rs | 43 ++++++++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 34e6dc282..fab29aaa8 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -59,6 +59,7 @@ where impl<'a, N> ParseMessageBytes<'a> for Question where + // TODO: Reduce to 'ParseMessageBytes'. N: SplitMessageBytes<'a>, { fn parse_message_bytes( diff --git a/src/new_net/server/request.rs b/src/new_net/server/request.rs index 86808a0f3..49e71ce6f 100644 --- a/src/new_net/server/request.rs +++ b/src/new_net/server/request.rs @@ -3,10 +3,10 @@ use core::ops::Range; use crate::new_base::{ - name::UnparsedName, + name::{Name, UnparsedName}, parse::{ParseMessageBytes, SplitMessageBytes}, - wire::{AsBytes, ParseError, SizePrefixed, U16}, - Message, Question, RType, Record, UnparsedRecordData, + wire::{AsBytes, ParseBytes, ParseError, SizePrefixed, U16}, + Message, QClass, QType, Question, RType, Record, UnparsedRecordData, }; /// A DNS request message. @@ -203,3 +203,40 @@ impl<'b> RequestMessage<'b> { Ok(this) } } + +//--- Inspection + +impl<'b> RequestMessage<'b> { + /// The sole question in the message. + /// + /// # Name Compression + /// + /// Due to the restrictions around compressed domain names (in order to + /// prevent attackers from crafting compression pointer loops), it is + /// guaranteed that the first QNAME in the message is uncompressed. + /// + /// # Errors + /// + /// Fails if there are zero or more than one question in the message. + pub fn sole_question(&self) -> Result, ParseError> { + if self.message.header.counts.questions.get() != 1 { + return Err(ParseError); + } + + // SAFETY: 'RequestMessage' is pre-validated. + let range = self.questions.1[0].clone(); + let range = range.start as usize..range.end as usize; + let contents = &self.message.contents[range]; + let qname = &contents[..contents.len() - 4]; + let qname = unsafe { Name::from_bytes_unchecked(qname) }; + let fields = &contents[contents.len() - 4..]; + let qtype = QType::parse_bytes(&fields[0..2]).unwrap(); + let qclass = QClass::parse_bytes(&fields[2..4]).unwrap(); + + Ok(Question { + qname, + qtype, + qclass, + }) + } +} From 64ea3481afec0d08d1d84c0be2f3c53bde02448f Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 29 Jan 2025 17:22:21 +0100 Subject: [PATCH 120/167] [new_net/server/request] Add the most important getters --- src/new_net/server/request.rs | 362 ++++++++++++++++++++++++++++++++-- 1 file changed, 344 insertions(+), 18 deletions(-) diff --git a/src/new_net/server/request.rs b/src/new_net/server/request.rs index 49e71ce6f..3f0150a2a 100644 --- a/src/new_net/server/request.rs +++ b/src/new_net/server/request.rs @@ -1,12 +1,17 @@ //! DNS request messages. -use core::ops::Range; - -use crate::new_base::{ - name::{Name, UnparsedName}, - parse::{ParseMessageBytes, SplitMessageBytes}, - wire::{AsBytes, ParseBytes, ParseError, SizePrefixed, U16}, - Message, QClass, QType, Question, RType, Record, UnparsedRecordData, +use core::{iter::FusedIterator, marker::PhantomData, ops::Range}; + +use crate::{ + new_base::{ + name::{Name, UnparsedName}, + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{AsBytes, ParseBytes, ParseError, SizePrefixed, U16}, + Message, ParseRecordData, QClass, QType, Question, RClass, RType, + Record, SectionCounts, UnparsedRecordData, TTL, + }, + new_edns::{EdnsOption, EdnsRecord}, + new_rdata::EdnsOptionsIter, }; /// A DNS request message. @@ -50,13 +55,14 @@ impl<'b> RequestMessage<'b> { let mut offset = range.start as usize; for _ in 0..number { - let (_question, rest) = + let (question, rest) = Question::<&UnparsedName>::split_message_bytes( contents, offset, )?; if let Some(indices) = indices.next() { - *indices = offset as u16..rest as u16; + let fields = offset + question.qname.len(); + *indices = offset as u16..fields as u16; } offset = rest; @@ -85,11 +91,10 @@ impl<'b> RequestMessage<'b> { >::split_message_bytes( contents, offset )?; - let data = offset + record.rname.len() + 10; - let range = offset as u16..data as u16; if let Some(indices) = indices.next() { - *indices = range.clone(); + let fields = offset + record.rname.len(); + *indices = offset as u16..fields as u16; } if section == 3 && record.rtype == RType::OPT { @@ -98,7 +103,7 @@ impl<'b> RequestMessage<'b> { return Err(ParseError); } - *edns_range = Some(range); + *edns_range = Some(offset as u16..rest as u16); } offset = rest; @@ -116,7 +121,7 @@ impl<'b> RequestMessage<'b> { indices: &mut [Range], ) -> Result<(), ParseError> { let mut indices = indices.iter_mut(); - let mut offset = range.start as usize; + let mut offset = range.start as usize + 11; while offset < range.end as usize { let (_type, rest) = @@ -204,6 +209,15 @@ impl<'b> RequestMessage<'b> { } } +//--- Internals + +impl<'b> RequestMessage<'b> { + /// The section counts. + fn counts(&self) -> &'b SectionCounts { + &self.message.header.counts + } +} + //--- Inspection impl<'b> RequestMessage<'b> { @@ -218,7 +232,7 @@ impl<'b> RequestMessage<'b> { /// # Errors /// /// Fails if there are zero or more than one question in the message. - pub fn sole_question(&self) -> Result, ParseError> { + pub fn sole_question(&self) -> Result, ParseError> { if self.message.header.counts.questions.get() != 1 { return Err(ParseError); } @@ -226,10 +240,9 @@ impl<'b> RequestMessage<'b> { // SAFETY: 'RequestMessage' is pre-validated. let range = self.questions.1[0].clone(); let range = range.start as usize..range.end as usize; - let contents = &self.message.contents[range]; - let qname = &contents[..contents.len() - 4]; + let qname = &self.message.contents[range.clone()]; let qname = unsafe { Name::from_bytes_unchecked(qname) }; - let fields = &contents[contents.len() - 4..]; + let fields = &self.message.contents[range.end..]; let qtype = QType::parse_bytes(&fields[0..2]).unwrap(); let qclass = QClass::parse_bytes(&fields[2..4]).unwrap(); @@ -239,4 +252,317 @@ impl<'b> RequestMessage<'b> { qclass, }) } + + /// The EDNS record in the message, if any. + pub fn edns_record(&self) -> Option> { + if self.edns.0.is_empty() { + return None; + } + + let range = self.edns.0.clone(); + let contents = &self.message.contents[..range.end as usize]; + EdnsRecord::parse_message_bytes(contents, range.start as usize) + .map(Some) + .expect("'RequestMessage' only holds well-formed EDNS records") + } + + /// The questions in the message. + /// + /// # Name Compression + /// + /// The returned questions use [`UnparsedName`] for the QNAMEs. These can + /// be resolved against the original message to determine the whole domain + /// name, if necessary. Note that decompression can fail. + pub fn questions(&self) -> RequestQuestions<'_, 'b> { + let contents = self.questions.0.clone(); + RequestQuestions { + message: self, + cache: self.questions.1.iter(), + contents: contents.start as usize..contents.end as usize, + indices: 0..self.counts().questions.get(), + } + } + + /// The answer records in the message. + /// + /// # Name Compression + /// + /// The returned records use [`UnparsedName`] for the RNAMEs. These can + /// be resolved against the original message to determine the whole domain + /// name, if necessary. Note that decompression can fail. + /// + /// # Record Data + /// + /// The caller can select an appropriate record data type to use. In most + /// cases, [`RecordData`](crate::new_rdata::RecordData) is appropriate; if + /// many records will be skipped, however, [`UnparsedRecordData`] might be + /// preferable. + pub fn answers(&self) -> RequestRecords<'_, 'b, D> + where + D: ParseRecordData<'b>, + { + let contents = self.answers.0.clone(); + RequestRecords { + message: self, + cache: self.answers.1.iter(), + contents: contents.start as usize..contents.end as usize, + indices: 0..self.counts().answers.get(), + _rdata: PhantomData, + } + } + + /// The authority records in the message. + /// + /// # Name Compression + /// + /// The returned records use [`UnparsedName`] for the RNAMEs. These can + /// be resolved against the original message to determine the whole domain + /// name, if necessary. Note that decompression can fail. + /// + /// # Record Data + /// + /// The caller can select an appropriate record data type to use. In most + /// cases, [`RecordData`](crate::new_rdata::RecordData) is appropriate; if + /// many records will be skipped, however, [`UnparsedRecordData`] might be + /// preferable. + pub fn authorities(&self) -> RequestRecords<'_, 'b, D> + where + D: ParseRecordData<'b>, + { + let contents = self.authorities.0.clone(); + RequestRecords { + message: self, + cache: self.authorities.1.iter(), + contents: contents.start as usize..contents.end as usize, + indices: 0..self.counts().authorities.get(), + _rdata: PhantomData, + } + } + + /// The additional records in the message. + /// + /// # Name Compression + /// + /// The returned records use [`UnparsedName`] for the RNAMEs. These can + /// be resolved against the original message to determine the whole domain + /// name, if necessary. Note that decompression can fail. + /// + /// # Record Data + /// + /// The caller can select an appropriate record data type to use. In most + /// cases, [`RecordData`](crate::new_rdata::RecordData) is appropriate; if + /// many records will be skipped, however, [`UnparsedRecordData`] might be + /// preferable. + pub fn additional(&self) -> RequestRecords<'_, 'b, D> + where + D: ParseRecordData<'b>, + { + let contents = self.additional.0.clone(); + RequestRecords { + message: self, + cache: self.additional.1.iter(), + contents: contents.start as usize..contents.end as usize, + indices: 0..self.counts().additional.get(), + _rdata: PhantomData, + } + } + + /// The EDNS options in the message. + pub fn edns_options(&self) -> RequestEdnsOptions<'b> { + let start = self.edns.0.start as usize + 11; + let end = self.edns.0.end as usize; + let options = &self.message.contents[start..end]; + RequestEdnsOptions { + inner: EdnsOptionsIter::new(options), + indices: 0..self.edns.1, + } + } } + +//----------- RequestQuestions ----------------------------------------------- + +/// The questions in a [`RequestMessage`]. +#[derive(Clone)] +pub struct RequestQuestions<'r, 'b> { + /// The underlying request message. + message: &'r RequestMessage<'b>, + + /// The cached question ranges. + cache: core::slice::Iter<'r, Range>, + + /// The range of message contents to parse. + contents: Range, + + /// The range of record indices left. + indices: Range, +} + +impl<'b> Iterator for RequestQuestions<'_, 'b> { + type Item = Question<&'b UnparsedName>; + + fn next(&mut self) -> Option { + // Try loading a cached question. + if let Some(range) = self.cache.next().cloned() { + if range.is_empty() { + // There are no more questions, stop. + self.cache = Default::default(); + self.contents.start = self.contents.end; + return None; + } + + // SAFETY: 'RequestMessage' is pre-validated. + let range = range.start as usize..range.end as usize; + let qname = &self.message.message.contents[range.clone()]; + let qname = unsafe { UnparsedName::from_bytes_unchecked(qname) }; + let fields = &self.message.message.contents[range.end..]; + let qtype = QType::parse_bytes(&fields[0..2]).unwrap(); + let qclass = QClass::parse_bytes(&fields[2..4]).unwrap(); + + self.indices.start += 1; + return Some(Question { + qname, + qtype, + qclass, + }); + } + + let _ = self.indices.next()?; + let contents = &self.message.message.contents[..self.contents.end]; + let (question, rest) = + Question::split_message_bytes(contents, self.contents.start) + .expect("'RequestMessage' only contains valid questions"); + + self.contents.start = rest; + Some(question) + } +} + +impl ExactSizeIterator for RequestQuestions<'_, '_> { + fn len(&self) -> usize { + self.indices.len() + } +} + +impl FusedIterator for RequestQuestions<'_, '_> {} + +//----------- RequestRecords ------------------------------------------------- + +/// The records in a section of a [`RequestMessage`]. +#[derive(Clone)] +pub struct RequestRecords<'r, 'b, D> { + /// The underlying request message. + message: &'r RequestMessage<'b>, + + /// The cached record ranges. + cache: core::slice::Iter<'r, Range>, + + /// The range of message contents to parse. + contents: Range, + + /// The range of record indices left. + indices: Range, + + /// A representation of the record data held. + _rdata: PhantomData<&'r [D]>, +} + +impl<'b, D> Iterator for RequestRecords<'_, 'b, D> +where + D: ParseRecordData<'b>, +{ + type Item = Result, ParseError>; + + fn next(&mut self) -> Option { + // Try loading a cached record. + if let Some(range) = self.cache.next().cloned() { + if range.is_empty() { + // There are no more records, stop. + self.cache = Default::default(); + self.contents.start = self.contents.end; + return None; + } + + // SAFETY: 'RequestMessage' is pre-validated. + let range = range.start as usize..range.end as usize; + let rname = &self.message.message.contents[range.clone()]; + let rname = unsafe { UnparsedName::from_bytes_unchecked(rname) }; + let fields = &self.message.message.contents[range.end..]; + let rtype = RType::parse_bytes(&fields[0..2]).unwrap(); + let rclass = RClass::parse_bytes(&fields[2..4]).unwrap(); + let ttl = TTL::parse_bytes(&fields[4..8]).unwrap(); + let size = U16::parse_bytes(&fields[8..10]).unwrap(); + let rdata_end = range.end + 10 + size.get() as usize; + let rdata = &self.message.message.contents[..rdata_end]; + let rdata = + match D::parse_record_data(rdata, range.end + 10, rtype) { + Ok(rdata) => rdata, + Err(err) => return Some(Err(err)), + }; + + self.indices.start += 1; + return Some(Ok(Record { + rname, + rtype, + rclass, + ttl, + rdata, + })); + } + + let _ = self.indices.next()?; + let contents = &self.message.message.contents[..self.contents.end]; + let (record, rest) = match Record::split_message_bytes( + contents, + self.contents.start, + ) { + Ok((record, rest)) => (record, rest), + Err(err) => return Some(Err(err)), + }; + + self.contents.start = rest; + Some(Ok(record)) + } +} + +impl<'b, D> ExactSizeIterator for RequestRecords<'_, 'b, D> +where + D: ParseRecordData<'b>, +{ + fn len(&self) -> usize { + self.indices.len() + } +} + +impl<'b, D> FusedIterator for RequestRecords<'_, 'b, D> where + D: ParseRecordData<'b> +{ +} + +//----------- RequestEdnsOptions --------------------------------------------- + +/// The EDNS options in a [`RequestMessage`]. +#[derive(Clone)] +pub struct RequestEdnsOptions<'b> { + /// The underlying iterator. + inner: EdnsOptionsIter<'b>, + + /// The range of option indices left. + indices: Range, +} + +impl<'b> Iterator for RequestEdnsOptions<'b> { + type Item = EdnsOption<'b>; + + fn next(&mut self) -> Option { + let _ = self.indices.next()?; + self.inner.next().map(Result::unwrap) + } +} + +impl ExactSizeIterator for RequestEdnsOptions<'_> { + fn len(&self) -> usize { + self.indices.len() + } +} + +impl FusedIterator for RequestEdnsOptions<'_> {} From 133c8e989e04aafc5c372de90e28f62ee9475908 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 29 Jan 2025 17:30:13 +0100 Subject: [PATCH 121/167] [new_net/server] Fix broken doc links --- src/new_net/server/mod.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/new_net/server/mod.rs b/src/new_net/server/mod.rs index cf8b7d358..3d9745581 100644 --- a/src/new_net/server/mod.rs +++ b/src/new_net/server/mod.rs @@ -75,7 +75,7 @@ pub trait LocalService { /// A producer of DNS responses. /// /// This type returns components to insert in a DNS response message. It - /// is constructed by [`Self::respond()`]. + /// is constructed by [`Self::respond_local()`]. /// /// # Lifetimes /// @@ -149,8 +149,8 @@ pub trait LocalServiceLayer { /// A producer of DNS responses. /// /// This type returns components to insert in a DNS response message. It - /// is constructed by [`Self::respond()`], if a response is returned early - /// (without the wrapped service interacting with it). + /// is constructed by [`Self::respond_local()`], if a response is returned + /// early (without the wrapped service interacting with it). /// /// # Lifetimes /// @@ -162,7 +162,8 @@ pub trait LocalServiceLayer { /// /// This type modifies the response from the wrapped service, by adding, /// removing, or modifying the components of the response message. It is - /// constructed by [`Self::respond()`], if an early return does not occur. + /// constructed by [`Self::respond_local()`], if an early return does not + /// occur. /// /// # Lifetimes /// @@ -180,8 +181,7 @@ pub trait LocalServiceLayer { /// the wrapped service, [`ControlFlow::Break`] is returned. /// /// The returned future does not implement [`Send`]. Use [`ServiceLayer`] - /// and [`ServiceLayer::respond_local()`] for a [`Send`]-implementing - /// version. + /// and [`ServiceLayer::respond()`] for a [`Send`]-implementing version. #[allow(async_fn_in_trait)] async fn respond_local( &self, From 44125fccfa57fccc0c91077175beb9cdef2e6132 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 29 Jan 2025 17:37:00 +0100 Subject: [PATCH 122/167] [new_net/server/request] Make method docs clearer --- src/new_net/server/mod.rs | 2 +- src/new_net/server/request.rs | 27 +++++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/new_net/server/mod.rs b/src/new_net/server/mod.rs index 3d9745581..e02688d62 100644 --- a/src/new_net/server/mod.rs +++ b/src/new_net/server/mod.rs @@ -23,7 +23,7 @@ use crate::new_base::{ mod impls; -mod request; +pub mod request; pub use request::RequestMessage; //----------- Service -------------------------------------------------------- diff --git a/src/new_net/server/request.rs b/src/new_net/server/request.rs index 3f0150a2a..8a0eadc48 100644 --- a/src/new_net/server/request.rs +++ b/src/new_net/server/request.rs @@ -268,6 +268,8 @@ impl<'b> RequestMessage<'b> { /// The questions in the message. /// + /// An iterator of [`Question`]s is returned. + /// /// # Name Compression /// /// The returned questions use [`UnparsedName`] for the QNAMEs. These can @@ -285,6 +287,8 @@ impl<'b> RequestMessage<'b> { /// The answer records in the message. /// + /// An iterator of [`Record`]s is returned. + /// /// # Name Compression /// /// The returned records use [`UnparsedName`] for the RNAMEs. These can @@ -313,6 +317,8 @@ impl<'b> RequestMessage<'b> { /// The authority records in the message. /// + /// An iterator of [`Record`]s is returned. + /// /// # Name Compression /// /// The returned records use [`UnparsedName`] for the RNAMEs. These can @@ -341,6 +347,8 @@ impl<'b> RequestMessage<'b> { /// The additional records in the message. /// + /// An iterator of [`Record`]s is returned. + /// /// # Name Compression /// /// The returned records use [`UnparsedName`] for the RNAMEs. These can @@ -368,6 +376,17 @@ impl<'b> RequestMessage<'b> { } /// The EDNS options in the message. + /// + /// An iterator of (results of) [`EdnsOption`]s is returned. + /// + /// If the message does not contain an EDNS record, this is empty. + /// + /// # Errors + /// + /// While the overall structure of the EDNS record is verified when the + /// [`RequestMessage`] is constructed, option-specific validation (e.g. + /// checking the size of an EDNS cookie option) is not performed early. + /// Those errors will be caught and returned by the iterator. pub fn edns_options(&self) -> RequestEdnsOptions<'b> { let start = self.edns.0.start as usize + 11; let end = self.edns.0.end as usize; @@ -541,6 +560,10 @@ impl<'b, D> FusedIterator for RequestRecords<'_, 'b, D> where //----------- RequestEdnsOptions --------------------------------------------- /// The EDNS options in a [`RequestMessage`]. +/// +/// This is a wrapper around [`EdnsOptionsIter`] that also implements +/// [`ExactSizeIterator`], as the total number of EDNS options is cached in +/// the [`RequestMessage`]. #[derive(Clone)] pub struct RequestEdnsOptions<'b> { /// The underlying iterator. @@ -551,11 +574,11 @@ pub struct RequestEdnsOptions<'b> { } impl<'b> Iterator for RequestEdnsOptions<'b> { - type Item = EdnsOption<'b>; + type Item = Result, ParseError>; fn next(&mut self) -> Option { let _ = self.indices.next()?; - self.inner.next().map(Result::unwrap) + self.inner.next() } } From 5aadd4bdc512c904104c7af8f211df9fd7e1940f Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 30 Jan 2025 10:59:52 +0100 Subject: [PATCH 123/167] [new_net/server] Define a simple UDP transport --- src/new_net/server/mod.rs | 2 + src/new_net/server/transport/mod.rs | 117 ++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 src/new_net/server/transport/mod.rs diff --git a/src/new_net/server/mod.rs b/src/new_net/server/mod.rs index e02688d62..4c33c254d 100644 --- a/src/new_net/server/mod.rs +++ b/src/new_net/server/mod.rs @@ -26,6 +26,8 @@ mod impls; pub mod request; pub use request::RequestMessage; +pub mod transport; + //----------- Service -------------------------------------------------------- /// A (multi-threaded) DNS service, that computes responses for requests. diff --git a/src/new_net/server/transport/mod.rs b/src/new_net/server/transport/mod.rs new file mode 100644 index 000000000..677d5a01b --- /dev/null +++ b/src/new_net/server/transport/mod.rs @@ -0,0 +1,117 @@ +//! Network transports for DNS servers. + +use core::net::SocketAddr; +use std::{io, sync::Arc, vec::Vec}; + +use tokio::net::UdpSocket; + +use crate::{ + new_base::{ + build::{BuilderContext, MessageBuilder}, + wire::{AsBytes, ParseBytesByRef}, + Message, + }, + new_net::server::{ProduceMessage, RequestMessage}, +}; + +use super::Service; + +//----------- serve_udp() ---------------------------------------------------- + +/// Serve DNS requests over UDP. +/// +/// A UDP socket will be bound to the given address and listened on for DNS +/// requests. Requests will be handed to the given [`Service`] and responses +/// will be returned directly. Each DNS request is handed off to a Tokio task +/// so they can respond asynchronously. +pub async fn serve_udp( + addr: SocketAddr, + service: impl Service + Send + 'static, +) -> io::Result<()> { + /// Internal multi-threaded state. + struct State { + /// The UDP socket serving DNS. + socket: UdpSocket, + + /// The service implementing response logic. + service: S, + } + + impl State { + /// Respond to a particular UDP request. + async fn respond( + self: Arc, + mut buffer: Vec, + peer: SocketAddr, + ) { + let Ok(message) = Message::parse_bytes_by_ref(&buffer) else { + // This message is fundamentally invalid, just give up. + return; + }; + + let Ok(request) = RequestMessage::new(message) else { + // This message is malformed; inform the peer and stop. + let mut buffer = [0u8; 12]; + let response = Message::parse_bytes_by_mut(&mut buffer) + .expect("Any 12-byte or larger buffer is a 'Message'"); + response.header.id = message.header.id; + response.header.flags = message.header.flags.respond(1); + let response = response.slice_to(0); + let _ = self.socket.send_to(response.as_bytes(), peer).await; + return; + }; + + // Generate the appropriate response. + let mut producer = self.service.respond(&request).await; + + // Build up the response message. + + buffer.clear(); + buffer.resize(65536, 0); + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(&mut buffer, &mut context); + + producer.header(builder.header_mut()); + while let Some(question) = producer.next_question(&mut builder) { + question.commit(); + } + while let Some(answer) = producer.next_answer(&mut builder) { + answer.commit(); + } + while let Some(authority) = producer.next_authority(&mut builder) + { + authority.commit(); + } + while let Some(additional) = + producer.next_additional(&mut builder) + { + additional.commit(); + } + + // Send the response back to the peer. + let _ = self + .socket + .send_to(builder.message().as_bytes(), peer) + .await; + } + } + + // Generate internal state. + let state = Arc::new(State { + socket: UdpSocket::bind(addr).await?, + service, + }); + + // Main loop: wait on new requests. + loop { + // Allocate a buffer for the request. + let mut buffer = vec![0u8; 65536]; + + // Receive a DNS request. + let (size, peer) = state.socket.recv_from(&mut buffer).await?; + buffer.truncate(size); + + // Spawn a Tokio task to respond to the request. + tokio::task::spawn(state.clone().respond(buffer, peer)); + } +} From 6fe81d9ad6d339253836722aab68983d4857d3ca Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 30 Jan 2025 19:17:12 +0100 Subject: [PATCH 124/167] Move 'new_net::server' to 'new_server' --- src/lib.rs | 3 ++- src/new_net/mod.rs | 6 ------ src/{new_net/server => new_server}/impls.rs | 0 src/{new_net/server => new_server}/mod.rs | 0 src/{new_net/server => new_server}/request.rs | 0 src/{new_net/server => new_server}/transport/mod.rs | 13 +++++-------- 6 files changed, 7 insertions(+), 15 deletions(-) delete mode 100644 src/new_net/mod.rs rename src/{new_net/server => new_server}/impls.rs (100%) rename src/{new_net/server => new_server}/mod.rs (100%) rename src/{new_net/server => new_server}/request.rs (100%) rename src/{new_net/server => new_server}/transport/mod.rs (94%) diff --git a/src/lib.rs b/src/lib.rs index 0709aa1a3..b8e02d032 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -212,5 +212,6 @@ pub mod zonetree; pub mod new_base; pub mod new_edns; -pub mod new_net; pub mod new_rdata; + +pub mod new_server; diff --git a/src/new_net/mod.rs b/src/new_net/mod.rs deleted file mode 100644 index a28ed0b4b..000000000 --- a/src/new_net/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -//! Sending and receiving DNS messages. - -#![cfg(feature = "net")] -#![cfg_attr(docsrs, doc(cfg(feature = "net")))] - -pub mod server; diff --git a/src/new_net/server/impls.rs b/src/new_server/impls.rs similarity index 100% rename from src/new_net/server/impls.rs rename to src/new_server/impls.rs diff --git a/src/new_net/server/mod.rs b/src/new_server/mod.rs similarity index 100% rename from src/new_net/server/mod.rs rename to src/new_server/mod.rs diff --git a/src/new_net/server/request.rs b/src/new_server/request.rs similarity index 100% rename from src/new_net/server/request.rs rename to src/new_server/request.rs diff --git a/src/new_net/server/transport/mod.rs b/src/new_server/transport/mod.rs similarity index 94% rename from src/new_net/server/transport/mod.rs rename to src/new_server/transport/mod.rs index 677d5a01b..e7aa9350b 100644 --- a/src/new_net/server/transport/mod.rs +++ b/src/new_server/transport/mod.rs @@ -5,16 +5,13 @@ use std::{io, sync::Arc, vec::Vec}; use tokio::net::UdpSocket; -use crate::{ - new_base::{ - build::{BuilderContext, MessageBuilder}, - wire::{AsBytes, ParseBytesByRef}, - Message, - }, - new_net::server::{ProduceMessage, RequestMessage}, +use crate::new_base::{ + build::{BuilderContext, MessageBuilder}, + wire::{AsBytes, ParseBytesByRef}, + Message, }; -use super::Service; +use super::{ProduceMessage, RequestMessage, Service}; //----------- serve_udp() ---------------------------------------------------- From 7a4fa11ccd5f570a5a5bf2d4251f1d41cf27f630 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 3 Feb 2025 10:44:00 +0100 Subject: [PATCH 125/167] [new_server] Overhaul API for more flexibility - Layers can now modify requests before passing them forward. - Layers can now see the entire request, allowing them to ensure that it is entirely consistent. Previously, they could only see one piece at a time, in order, preventing them from modifying a record based on the contents of a record following it. - Layers can now truncate a message when they see fit. - 'ParsedMessage' is used to represent requests and responses; it will allocate on the heap, but it provides extremely fast traversal of the entire message (as every message component is directly indexable). - EDNS support is hardcoded in 'ParsedMessage'; this could make it brittle to future changes in the DNS specs, but service layers rely on EDNS and need to be able to find EDNS options quickly. - Trait objects are better supported, in terms of performance; there will only be two method calls to every 'dyn ServiceLayer'. The old API would have one method call for every message component in the request and the response. - The 'Service' and 'ServiceLayer' traits no longer have associated types, and are now possible to use as trait objects directly. --- Cargo.lock | 2 +- Cargo.toml | 7 +- src/new_base/build/message.rs | 19 +- src/new_base/build/question.rs | 8 +- src/new_base/build/record.rs | 8 +- src/new_base/question.rs | 2 +- src/new_base/record.rs | 2 +- src/new_edns/mod.rs | 2 +- src/new_rdata/edns.rs | 8 + src/new_rdata/mod.rs | 34 + src/new_server/exchange.rs | 451 ++++++++++++ src/new_server/impls.rs | 1163 ++++++------------------------- src/new_server/mod.rs | 352 ++-------- src/new_server/request.rs | 591 ---------------- src/new_server/transport/mod.rs | 70 +- 15 files changed, 803 insertions(+), 1916 deletions(-) create mode 100644 src/new_server/exchange.rs delete mode 100644 src/new_server/request.rs diff --git a/Cargo.lock b/Cargo.lock index c53a0803b..ce92a567a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -232,10 +232,10 @@ version = "0.10.3" dependencies = [ "arbitrary", "arc-swap", + "bumpalo", "bytes", "chrono", "domain-macros", - "either", "futures-util", "hashbrown 0.14.5", "heapless", diff --git a/Cargo.toml b/Cargo.toml index 071e17a21..6ea514ec9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,9 +22,8 @@ license = "BSD-3-Clause" [dependencies] domain-macros = { path = "./macros", version = "0.10.3" } -either = { version = "1.10.0", default-features = false } - arbitrary = { version = "1.4.1", optional = true, features = ["derive"] } +bumpalo = { version = "3.12", optional = true } octseq = { version = "0.5.2", default-features = false } time = { version = "0.3.1", default-features = false } rand = { version = "0.8", optional = true } @@ -60,7 +59,7 @@ bytes = ["dep:bytes", "octseq/bytes"] heapless = ["dep:heapless", "octseq/heapless"] serde = ["dep:serde", "octseq/serde"] smallvec = ["dep:smallvec", "octseq/smallvec"] -std = ["dep:hashbrown", "either/use_std", "bytes?/std", "octseq/std", "time/std"] +std = ["dep:hashbrown", "bytes?/std", "octseq/std", "time/std"] tracing = ["dep:log", "dep:tracing"] # Cryptographic backends @@ -76,7 +75,7 @@ zonefile = ["bytes", "serde", "std"] # Unstable features unstable-client-transport = ["moka", "net", "tracing"] -unstable-server-transport = ["arc-swap", "chrono/clock", "libc", "net", "siphasher", "tracing"] +unstable-server-transport = ["dep:bumpalo", "arc-swap", "chrono/clock", "libc", "net", "siphasher", "tracing"] unstable-sign = ["std", "dep:secrecy", "unstable-validate", "time/formatting"] unstable-stelline = ["tokio/test-util", "tracing", "tracing-subscriber", "tsig", "unstable-client-transport", "unstable-server-transport", "zonefile"] unstable-validate = ["bytes", "std", "ring"] diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 5e969115f..6f5311383 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -19,17 +19,17 @@ use super::{ /// This is a high-level building interface, offering methods to put together /// entire questions and records. It directly writes into an allocated buffer /// (on the stack or the heap). -pub struct MessageBuilder<'b> { +pub struct MessageBuilder<'b, 'c> { /// The message being constructed. pub(super) message: &'b mut Message, /// Context for building. - pub(super) context: &'b mut BuilderContext, + pub(super) context: &'c mut BuilderContext, } //--- Initialization -impl<'b> MessageBuilder<'b> { +impl<'b, 'c> MessageBuilder<'b, 'c> { /// Initialize an empty [`MessageBuilder`]. /// /// The message header is left uninitialized. use [`Self::header_mut()`] @@ -41,7 +41,7 @@ impl<'b> MessageBuilder<'b> { /// possible size for a DNS message). pub fn new( buffer: &'b mut [u8], - context: &'b mut BuilderContext, + context: &'c mut BuilderContext, ) -> Self { let message = Message::parse_bytes_by_mut(buffer) .expect("The caller's buffer is at least 12 bytes big"); @@ -52,7 +52,7 @@ impl<'b> MessageBuilder<'b> { //--- Inspection -impl MessageBuilder<'_> { +impl MessageBuilder<'_, '_> { /// The message header. pub fn header(&self) -> &Header { &self.message.header @@ -86,9 +86,14 @@ impl MessageBuilder<'_> { //--- Interaction -impl MessageBuilder<'_> { +impl<'b> MessageBuilder<'b, '_> { + /// End the builder, returning the built message. + pub fn finish(self) -> &'b Message { + self.message + } + /// Reborrow the builder with a shorter lifetime. - pub fn reborrow(&mut self) -> MessageBuilder<'_> { + pub fn reborrow(&mut self) -> MessageBuilder<'_, '_> { MessageBuilder { message: self.message, context: self.context, diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs index 95fa095ae..addd72d91 100644 --- a/src/new_base/build/question.rs +++ b/src/new_base/build/question.rs @@ -19,7 +19,7 @@ use super::{BuildCommitted, BuildIntoMessage, MessageBuilder, MessageState}; /// commit (finish building) or cancel (remove) the question. pub struct QuestionBuilder<'b> { /// The underlying message builder. - builder: MessageBuilder<'b>, + builder: MessageBuilder<'b, 'b>, /// The offset of the question name. name: u16, @@ -33,7 +33,7 @@ impl<'b> QuestionBuilder<'b> { /// The provided builder must be empty (i.e. must not have uncommitted /// content). pub(super) fn build( - mut builder: MessageBuilder<'b>, + mut builder: MessageBuilder<'b, 'b>, question: &Question, ) -> Result { // TODO: Require that the QNAME serialize correctly? @@ -51,7 +51,7 @@ impl<'b> QuestionBuilder<'b> { /// `builder.message().contents[name..]` must represent a valid /// [`Question`] in the wire format. pub unsafe fn from_raw_parts( - builder: MessageBuilder<'b>, + builder: MessageBuilder<'b, 'b>, name: u16, ) -> Self { Self { builder, name } @@ -84,7 +84,7 @@ impl<'b> QuestionBuilder<'b> { } /// Deconstruct this [`QuestionBuilder`] into its raw parts. - pub fn into_raw_parts(self) -> (MessageBuilder<'b>, u16) { + pub fn into_raw_parts(self) -> (MessageBuilder<'b, 'b>, u16) { (self.builder, self.name) } } diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index f74418a13..7d9a48b5d 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -23,7 +23,7 @@ use super::{ /// cancel (remove) the record. pub struct RecordBuilder<'b> { /// The underlying message builder. - builder: MessageBuilder<'b>, + builder: MessageBuilder<'b, 'b>, /// The offset of the record name. name: u16, @@ -40,7 +40,7 @@ impl<'b> RecordBuilder<'b> { /// The provided builder must be empty (i.e. must not have uncommitted /// content). pub(super) fn build( - mut builder: MessageBuilder<'b>, + mut builder: MessageBuilder<'b, 'b>, record: &Record, ) -> Result where @@ -97,7 +97,7 @@ impl<'b> RecordBuilder<'b> { /// [`Record`] in the wire format. `contents[data..]` must represent the /// record data (i.e. immediately after the record data size field). pub unsafe fn from_raw_parts( - builder: MessageBuilder<'b>, + builder: MessageBuilder<'b, 'b>, name: u16, data: u16, ) -> Self { @@ -151,7 +151,7 @@ impl<'b> RecordBuilder<'b> { } /// Deconstruct this [`RecordBuilder`] into its raw parts. - pub fn into_raw_parts(self) -> (MessageBuilder<'b>, u16, u16) { + pub fn into_raw_parts(self) -> (MessageBuilder<'b, 'b>, u16, u16) { let (name, data) = (self.name, self.data); let this = ManuallyDrop::new(self); let this = (&*this) as *const Self; diff --git a/src/new_base/question.rs b/src/new_base/question.rs index fab29aaa8..dbcc2e3c9 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -12,7 +12,7 @@ use super::{ //----------- Question ------------------------------------------------------- /// A DNS question. -#[derive(Clone, BuildBytes, ParseBytes, SplitBytes)] +#[derive(Clone, Debug, BuildBytes, ParseBytes, SplitBytes)] pub struct Question { /// The domain name being requested. pub qname: N, diff --git a/src/new_base/record.rs b/src/new_base/record.rs index badbab04e..eef4c2840 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -15,7 +15,7 @@ use super::{ //----------- Record --------------------------------------------------------- /// A DNS record. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Record { /// The name of the record. pub rname: N, diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index a0640dac1..f2cf6b710 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -191,7 +191,7 @@ impl fmt::Debug for EdnsFlags { //----------- EdnsOption ----------------------------------------------------- /// An Extended DNS option. -#[derive(Debug)] +#[derive(Clone, Debug)] #[non_exhaustive] pub enum EdnsOption<'b> { /// A client's request for a DNS cookie. diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 43327b50c..d832d9f5a 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -26,6 +26,14 @@ pub struct Opt { contents: [u8], } +//--- Associated Constants + +impl Opt { + /// Empty OPT record data. + pub const EMPTY: &'static Self = + unsafe { core::mem::transmute(&[] as &'static [u8]) }; +} + //--- Inspection impl Opt { diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 70f041240..8491a3c6d 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -63,6 +63,40 @@ pub enum RecordData<'a, N> { Unknown(RType, &'a UnknownRecordData), } +impl<'a, N> RecordData<'a, N> { + /// Transform the compressed domain names in this record data. + pub fn map_names R>( + self, + mut f: F, + ) -> RecordData<'a, R> { + match self { + Self::A(r) => RecordData::A(r), + Self::Ns(r) => RecordData::Ns(Ns { name: (f)(r.name) }), + Self::CName(r) => RecordData::CName(CName { name: (f)(r.name) }), + Self::Soa(r) => RecordData::Soa(Soa { + mname: (f)(r.mname), + rname: (f)(r.rname), + serial: r.serial, + refresh: r.refresh, + retry: r.retry, + expire: r.expire, + minimum: r.minimum, + }), + Self::Wks(r) => RecordData::Wks(r), + Self::Ptr(r) => RecordData::Ptr(Ptr { name: (f)(r.name) }), + Self::HInfo(r) => RecordData::HInfo(r), + Self::Mx(r) => RecordData::Mx(Mx { + preference: r.preference, + exchange: (f)(r.exchange), + }), + Self::Txt(r) => RecordData::Txt(r), + Self::Aaaa(r) => RecordData::Aaaa(r), + Self::Opt(r) => RecordData::Opt(r), + Self::Unknown(t, r) => RecordData::Unknown(t, r), + } + } +} + //--- Parsing record data impl<'a, N> ParseRecordData<'a> for RecordData<'a, N> diff --git a/src/new_server/exchange.rs b/src/new_server/exchange.rs new file mode 100644 index 000000000..3b303a360 --- /dev/null +++ b/src/new_server/exchange.rs @@ -0,0 +1,451 @@ +//! Request-response exchanges for DNS servers. + +use core::any::{Any, TypeId}; +use std::{boxed::Box, time::SystemTime, vec::Vec}; + +use bumpalo::Bump; + +use crate::{ + new_base::{ + build::{BuilderContext, MessageBuilder}, + name::{RevName, RevNameBuf}, + parse::SplitMessageBytes, + wire::{BuildBytes, ParseError, TruncationError, U16}, + HeaderFlags, Message, Question, RType, Record, + }, + new_edns::EdnsOption, + new_rdata::{Opt, RecordData}, +}; + +//----------- Exchange ------------------------------------------------------- + +/// A DNS request-response exchange. +/// +/// An [`Exchange`] represents a request sent to a DNS server and the server's +/// response (as it is being built). It tracks basic information about the +/// request, such as when it was sent and the connection it originates from, +/// as well as metadata stored by layers in the DNS server. +pub struct Exchange<'a> { + /// An allocator for storing parts of the message. + pub alloc: Allocator<'a>, + + /// When the exchange began (i.e. when the request was received). + pub reception: SystemTime, + + /// The request message. + pub request: ParsedMessage<'a>, + + /// The response message being built. + pub response: ParsedMessage<'a>, + + /// Dynamic metadata stored by the DNS server. + pub metadata: Vec, +} + +//----------- OutgoingResponse ----------------------------------------------- + +/// An [`Exchange`] with an initialized response message. +pub struct OutgoingResponse<'e, 'a> { + /// An allocator for storing parts of the message. + pub alloc: &'e mut Allocator<'a>, + + /// The response message being built. + pub response: &'e mut ParsedMessage<'a>, + + /// Dynamic metadata stored by the DNS server. + pub metadata: &'e mut Vec, +} + +impl<'e, 'a> OutgoingResponse<'e, 'a> { + /// Construct an [`OutgoingResponse`] on an [`Exchange`]. + pub fn new(exchange: &'e mut Exchange<'a>) -> Self { + Self { + alloc: &mut exchange.alloc, + response: &mut exchange.response, + metadata: &mut exchange.metadata, + } + } + + /// Reborrow this response for a shorter lifetime. + pub fn reborrow(&mut self) -> OutgoingResponse<'_, 'a> { + OutgoingResponse { + alloc: self.alloc, + response: self.response, + metadata: self.metadata, + } + } +} + +//----------- ParsedMessage -------------------------------------------------- + +/// A pre-parsed DNS message. +/// +/// This is a simple representation of DNS messages outside the wire format, +/// making it easy to inspect and modify them efficiently. +#[derive(Clone, Default, Debug)] +pub struct ParsedMessage<'a> { + /// The message ID. + pub id: U16, + + /// The message flags. + pub flags: HeaderFlags, + + /// Questions in the message. + pub questions: Vec>, + + /// Answer records in the message. + pub answers: Vec>>, + + /// Authority records in the message. + pub authorities: Vec>>, + + /// Additional records in the message. + /// + /// If there is an EDNS record, it will be included here, but its record + /// data (which contains the EDNS options) will be empty. The options are + /// stored in the `options` field for easier access. + pub additional: Vec>>, + + /// EDNS options in the message. + /// + /// These options will be appended to the EDNS record in the additional + /// section (there must be one for any options to exist). The order of + /// the options is meaningless. + pub options: Vec>, +} + +impl<'a> ParsedMessage<'a> { + /// Parse an existing [`Message`]. + /// + /// Decompressed domain names are allocated using the given [`Bump`]. + pub fn parse( + message: &'a Message, + alloc: &mut Allocator<'a>, + ) -> Result { + type ParsedQuestion = Question; + type ParsedRecord<'a> = + Record>; + + /// Map a domain name by placing it in a [`Bump`]. + fn map_name<'a>( + name: RevNameBuf, + alloc: &mut Allocator<'a>, + ) -> &'a RevName { + // Allocate the domain name. + let name = alloc.alloc_slice_copy(name.as_bytes()); + // SAFETY: 'name' has the same bytes as the input 'name'. + unsafe { RevName::from_bytes_unchecked(name) } + } + + let mut this = Self::default(); + let mut offset = 0; + + // Parse the message header. + this.id = message.header.id; + this.flags = message.header.flags; + let counts = message.header.counts; + + // Parse the question section. + this.questions + .reserve(counts.questions.get().max(256) as usize); + for _ in 0..counts.questions.get() { + let (question, rest) = ParsedQuestion::split_message_bytes( + &message.contents, + offset, + )?; + + this.questions.push(Question { + qname: map_name(question.qname, alloc), + qtype: question.qtype, + qclass: question.qclass, + }); + offset = rest; + } + + // Parse the answer section. + this.answers.reserve(counts.answers.get().max(256) as usize); + for _ in 0..counts.answers.get() { + let (answer, rest) = + ParsedRecord::split_message_bytes(&message.contents, offset)?; + + this.answers.push(Record { + rname: map_name(answer.rname, alloc), + rtype: answer.rtype, + rclass: answer.rclass, + ttl: answer.ttl, + rdata: answer.rdata.map_names(|n| map_name(n, alloc)), + }); + offset = rest; + } + + // Parse the authority section. + this.authorities + .reserve(counts.authorities.get().max(256) as usize); + for _ in 0..counts.authorities.get() { + let (authority, rest) = + ParsedRecord::split_message_bytes(&message.contents, offset)?; + + this.authorities.push(Record { + rname: map_name(authority.rname, alloc), + rtype: authority.rtype, + rclass: authority.rclass, + ttl: authority.ttl, + rdata: authority.rdata.map_names(|n| map_name(n, alloc)), + }); + offset = rest; + } + + // The EDNS record data. + let mut edns_data = None; + + // Parse the additional section. + this.additional + .reserve(counts.additional.get().max(256) as usize); + for _ in 0..counts.additional.get() { + let (mut additional, rest) = + ParsedRecord::split_message_bytes(&message.contents, offset)?; + + if let RecordData::Opt(opt) = additional.rdata { + if edns_data.is_some() { + // A message cannot contain two distinct EDNS records. + return Err(ParseError); + } + + edns_data = Some(opt); + + // Deduplicate the EDNS data. + additional.rdata = RecordData::Opt(Opt::EMPTY); + } + + this.additional.push(Record { + rname: map_name(additional.rname, alloc), + rtype: additional.rtype, + rclass: additional.rclass, + ttl: additional.ttl, + rdata: additional.rdata.map_names(|n| map_name(n, alloc)), + }); + offset = rest; + } + + // Ensure there's no other content in the message. + if offset != message.contents.len() { + return Err(ParseError); + } + + // Parse EDNS options. + if let Some(edns_data) = edns_data { + for option in edns_data.options() { + this.options.push(option?); + } + } + + Ok(this) + } + + /// Build this message into the given buffer. + /// + /// If the message was too large, a [`TruncationError`] is returned. + pub fn build<'b>( + &self, + buffer: &'b mut [u8], + ) -> Result<&'b Message, TruncationError> { + // Construct a 'MessageBuilder'. + if buffer.len() < 12 { + return Err(TruncationError); + } + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(buffer, &mut context); + + // Build the message header. + let header = builder.header_mut(); + header.id = self.id; + header.flags = self.flags; + header.counts.questions.set(self.questions.len() as u16); + header.counts.answers.set(self.answers.len() as u16); + header.counts.authorities.set(self.authorities.len() as u16); + header.counts.additional.set(self.additional.len() as u16); + + // Build the question section. + for question in &self.questions { + builder + .build_question(question)? + .expect("No answers, authorities, or additionals are built"); + } + + // Build the answer section. + for answer in &self.answers { + builder + .build_answer(answer)? + .expect("No authorities, or additionals are built"); + } + + // Build the authority section. + for authority in &self.authorities { + builder + .build_authority(authority)? + .expect("No additionals are built"); + } + + // Build the additional section. + let mut edns_built = false; + for additional in &self.additional { + if additional.rtype == RType::OPT { + // Technically, multiple OPT records are an error. But this + // isn't the right place to report that. + debug_assert!(!edns_built, "Multiple EDNS records found"); + + let mut builder = builder.build_additional(additional)?; + let mut delegate = builder.delegate(); + let mut uninit = delegate.uninitialized(); + for option in &self.options { + uninit = option.build_bytes(uninit)?; + } + let uninit_len = uninit.len(); + let appended = delegate.uninitialized().len() - uninit_len; + delegate.mark_appended(appended); + core::mem::drop(delegate); + builder.commit(); + + edns_built = true; + continue; + } + + builder.build_additional(additional)?; + } + + debug_assert!( + self.options.is_empty() || edns_built, + "EDNS options found, but no OPT record", + ); + + Ok(builder.finish()) + } +} + +impl ParsedMessage<'_> { + /// Reset this object to a blank message. + /// + /// This is helpful in order to reuse the underlying allocations. + pub fn reset(&mut self) { + self.id = U16::new(0); + self.flags = HeaderFlags::default(); + self.questions.clear(); + self.answers.clear(); + self.authorities.clear(); + self.additional.clear(); + self.options.clear(); + } +} + +//----------- Metadata ------------------------------------------------------- + +/// Arbitrary metadata about a DNS exchange. +/// +/// This is an enhanced version of `Box` that can +/// perform downcasting more efficiently. +pub struct Metadata { + /// The type ID of the object. + type_id: TypeId, + + /// The underlying object. + object: Box, +} + +impl Metadata { + /// Wrap an object in [`Metadata`]. + pub fn new(object: T) -> Self { + let type_id = TypeId::of::(); + let object = Box::new(object) as Box; + Self { type_id, object } + } + + /// Check whether this is metadata of a certain type. + pub fn is(&self) -> bool { + self.type_id == TypeId::of::() + } + + /// Try downcasting to a reference of a particular type. + pub fn try_as(&self) -> Option<&T> { + if !self.is::() { + return None; + } + + let pointer: *const (dyn Any + Send + 'static) = &*self.object; + // SAFETY: 'pointer' was created by 'Box::into_raw()', and thus is + // safe to dereference (the pointer will only be dropped when 'self' + // is, but that cannot happen during the current lifetime). + Some(unsafe { &*pointer.cast::() }) + } + + /// Try downcasting to a mutable reference of a particular type. + pub fn try_as_mut(&mut self) -> Option<&mut T> { + if !self.is::() { + return None; + } + + let pointer: *mut (dyn Any + Send + 'static) = &mut *self.object; + // SAFETY: 'pointer' was created by 'Box::into_raw()', and thus is + // safe to dereference (the pointer will only be dropped when 'self' + // is, but that cannot happen during the current lifetime). + Some(unsafe { &mut *pointer.cast::() }) + } + + /// Try moving this object out of the [`Metadata`]. + pub fn try_into(self) -> Result { + if !self.is::() { + return Err(self); + } + + let pointer: *mut _ = Box::into_raw(self.object); + // SAFETY: 'pointer' was created by 'Box::into_raw()', and thus is + // safe to move into the same 'Box'. + Ok(*unsafe { Box::from_raw(pointer.cast::()) }) + } +} + +//----------- Allocator ------------------------------------------------------ + +/// A bump allocator with a fixed lifetime. +/// +/// This is a wrapper around [`bumpalo::Bump`] that guarantees thread safety. +#[derive(Debug)] +#[repr(transparent)] +pub struct Allocator<'a> { + /// The underlying allocator. + /// + /// In order to share access to a [`Bump`], even on a single thread, it + /// must be a shared reference (`&'a Bump`). That is how we store it + /// here. However, we guarantee that the [`Allocator`] is constructed + /// from a mutable reference -- thus that this is the only reference to + /// the bump allocator. It is never exposed publicly, so it cannot be + /// copied and used from multiple threads. + inner: &'a Bump, +} + +impl<'a> Allocator<'a> { + /// Construct a new [`Allocator`]. + pub const fn new(inner: &'a mut Bump) -> Self { + // NOTE: The 'Bump' is mutably borrowed for lifetime 'a; the reference + // we store is thus guaranteed to be unique. + Self { inner } + } + + /// Allocate an object. + pub fn alloc(&mut self, val: T) -> &'a mut T { + self.inner.alloc(val) + } + + /// Allocate a slice and copy the given contents into it. + pub fn alloc_slice_copy(&mut self, src: &[T]) -> &'a mut [T] { + self.inner.alloc_slice_copy(src) + } +} + +// SAFETY: An 'Allocator' contains '&Bump', which is '!Send' because 'Bump' is +// '!Sync'. However, we guarantee that there are no other references to the +// 'Bump' -- that this is really '&mut Bump' (which is 'Send'). +unsafe impl Send for Allocator<'_> {} + +// NOTE: 'Allocator' acts a bit like the nightly-only 'std::sync::Exclusive', +// since it doesn't provide any shared access to the underlying 'Bump'. It is +// sound for it to implement 'Sync', but we defer this until necessary. diff --git a/src/new_server/impls.rs b/src/new_server/impls.rs index 88293d24a..ffabe4d3a 100644 --- a/src/new_server/impls.rs +++ b/src/new_server/impls.rs @@ -5,43 +5,36 @@ use core::ops::ControlFlow; #[cfg(feature = "std")] use std::{boxed::Box, rc::Rc, sync::Arc, vec::Vec}; -use either::Either::{self, Left, Right}; - -use crate::new_base::{ - build::{MessageBuilder, QuestionBuilder, RecordBuilder}, - Header, -}; - use super::{ - LocalService, LocalServiceLayer, ProduceMessage, RequestMessage, Service, - ServiceLayer, TransformMessage, + exchange::OutgoingResponse, Exchange, LocalService, LocalServiceLayer, + Service, ServiceLayer, }; //----------- impl Service --------------------------------------------------- impl Service for &T { - async fn respond(&self, request: &RequestMessage<'_>) -> Self::Producer { - T::respond(self, request).await + async fn respond(&self, exchange: &mut Exchange<'_>) { + T::respond(self, exchange).await } } impl Service for &mut T { - async fn respond(&self, request: &RequestMessage<'_>) -> Self::Producer { - T::respond(self, request).await + async fn respond(&self, exchange: &mut Exchange<'_>) { + T::respond(self, exchange).await } } #[cfg(feature = "std")] impl Service for Box { - async fn respond(&self, request: &RequestMessage<'_>) -> Self::Producer { - T::respond(self, request).await + async fn respond(&self, exchange: &mut Exchange<'_>) { + T::respond(self, exchange).await } } #[cfg(feature = "std")] impl Service for Arc { - async fn respond(&self, request: &RequestMessage<'_>) -> Self::Producer { - T::respond(self, request).await + async fn respond(&self, exchange: &mut Exchange<'_>) { + T::respond(self, exchange).await } } @@ -50,12 +43,11 @@ where A: ServiceLayer, S: Service, { - async fn respond(&self, request: &RequestMessage<'_>) -> Self::Producer { - match self.0.respond(request).await { - ControlFlow::Continue(t) => { - Right((t, self.1.respond(request).await)) - } - ControlFlow::Break(p) => Left(p), + async fn respond(&self, exchange: &mut Exchange<'_>) { + if self.0.process_incoming(exchange).await.is_continue() { + self.1.respond(exchange).await; + let response = OutgoingResponse::new(exchange); + self.0.process_outgoing(response).await; } } } @@ -68,13 +60,10 @@ macro_rules! impl_service_tuple { $($layers: ServiceLayer,)* $service: Service, { - async fn respond( - &self, - request: &RequestMessage<'_>, - ) -> Self::Producer { + async fn respond(&self, exchange: &mut Exchange<'_>) { #[allow(non_snake_case)] let ($($layers,)* $service,) = self; - (($($layers),*,), $service).respond(request).await + (($($layers),*,), $service).respond(exchange).await } } }; @@ -94,60 +83,35 @@ impl_service_tuple!(A B C D E F G H I J K..S); //----------- impl LocalService ---------------------------------------------- impl LocalService for &T { - type Producer = T::Producer; - - async fn respond_local( - &self, - request: &RequestMessage<'_>, - ) -> Self::Producer { - T::respond_local(self, request).await + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + T::respond_local(self, exchange).await } } impl LocalService for &mut T { - type Producer = T::Producer; - - async fn respond_local( - &self, - request: &RequestMessage<'_>, - ) -> Self::Producer { - T::respond_local(self, request).await + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + T::respond_local(self, exchange).await } } #[cfg(feature = "std")] impl LocalService for Box { - type Producer = T::Producer; - - async fn respond_local( - &self, - request: &RequestMessage<'_>, - ) -> Self::Producer { - T::respond_local(self, request).await + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + T::respond_local(self, exchange).await } } #[cfg(feature = "std")] impl LocalService for Rc { - type Producer = T::Producer; - - async fn respond_local( - &self, - request: &RequestMessage<'_>, - ) -> Self::Producer { - T::respond_local(self, request).await + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + T::respond_local(self, exchange).await } } #[cfg(feature = "std")] impl LocalService for Arc { - type Producer = T::Producer; - - async fn respond_local( - &self, - request: &RequestMessage<'_>, - ) -> Self::Producer { - T::respond_local(self, request).await + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + T::respond_local(self, exchange).await } } @@ -156,17 +120,11 @@ where A: LocalServiceLayer, S: LocalService, { - type Producer = Either; - - async fn respond_local( - &self, - request: &RequestMessage<'_>, - ) -> Self::Producer { - match self.0.respond_local(request).await { - ControlFlow::Continue(t) => { - Right((t, self.1.respond_local(request).await)) - } - ControlFlow::Break(p) => Left(p), + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + if self.0.process_local_incoming(exchange).await.is_continue() { + self.1.respond_local(exchange).await; + let response = OutgoingResponse::new(exchange); + self.0.process_local_outgoing(response).await; } } } @@ -179,16 +137,10 @@ macro_rules! impl_local_service_tuple { $($layers: LocalServiceLayer,)* $service: LocalService, { - type Producer = - <(($($layers),*,), $service) as LocalService>::Producer; - - async fn respond_local( - &self, - request: &RequestMessage<'_>, - ) -> Self::Producer { + async fn respond_local(&self, exchange: &mut Exchange<'_>) { #[allow(non_snake_case)] let ($($layers,)* $service,) = self; - (($($layers),*,), $service).respond_local(request).await + (($($layers),*,), $service).respond_local(exchange).await } } }; @@ -208,40 +160,56 @@ impl_local_service_tuple!(A B C D E F G H I J K..S); //----------- impl ServiceLayer ---------------------------------------------- impl ServiceLayer for &T { - async fn respond( + async fn process_incoming( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - T::respond(self, request).await + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_incoming(self, exchange).await + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + T::process_outgoing(self, response).await } } impl ServiceLayer for &mut T { - async fn respond( + async fn process_incoming( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - T::respond(self, request).await + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_incoming(self, exchange).await + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + T::process_outgoing(self, response).await } } #[cfg(feature = "std")] impl ServiceLayer for Box { - async fn respond( + async fn process_incoming( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - T::respond(self, request).await + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_incoming(self, exchange).await + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + T::process_outgoing(self, response).await } } #[cfg(feature = "std")] impl ServiceLayer for Arc { - async fn respond( + async fn process_incoming( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - T::respond(self, request).await + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_incoming(self, exchange).await + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + T::process_outgoing(self, response).await } } @@ -250,19 +218,17 @@ where A: ServiceLayer, B: ServiceLayer, { - async fn respond( + async fn process_incoming( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - let at = match ::respond(&self.0, request).await { - ControlFlow::Continue(at) => at, - ControlFlow::Break(ap) => return ControlFlow::Break(Left(ap)), - }; + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.0.process_incoming(exchange).await?; + self.1.process_incoming(exchange).await + } - match ::respond(&self.1, request).await { - ControlFlow::Continue(bt) => ControlFlow::Continue((at, bt)), - ControlFlow::Break(bp) => ControlFlow::Break(Right((at, bp))), - } + async fn process_outgoing(&self, mut response: OutgoingResponse<'_, '_>) { + self.1.process_outgoing(response.reborrow()).await; + self.0.process_outgoing(response.reborrow()).await } } @@ -275,14 +241,26 @@ macro_rules! impl_service_layer_tuple { $($middle: ServiceLayer,)+ $last: ServiceLayer, { - async fn respond( + async fn process_incoming( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { #[allow(non_snake_case)] let ($first, $($middle,)+ ref $last) = self; - ($first, ($($middle,)+ $last)).respond(request).await + $first.process_incoming(exchange).await?; + $($middle.process_incoming(exchange).await?;)+ + $last.process_incoming(exchange).await + } + + async fn process_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + #[allow(non_snake_case)] + let ($first, $($middle,)+ ref $last) = self; + ($first, ($($middle,)+ $last)) + .process_outgoing(response).await } } } @@ -301,100 +279,119 @@ impl_service_layer_tuple!(A..L: B C D E F G H I J K); #[cfg(feature = "std")] impl ServiceLayer for [T] { - async fn respond( + async fn process_incoming( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - let mut transformers = Vec::new(); + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { for layer in self { - match layer.respond(request).await { - ControlFlow::Continue(t) => transformers.push(t), - ControlFlow::Break(p) => { - return ControlFlow::Break((transformers.into(), p)); - } - } + layer.process_incoming(exchange).await?; + } + ControlFlow::Continue(()) + } + + async fn process_outgoing(&self, mut response: OutgoingResponse<'_, '_>) { + for layer in self.iter().rev() { + layer.process_outgoing(response.reborrow()).await; } - ControlFlow::Continue(transformers.into()) } } #[cfg(feature = "std")] impl ServiceLayer for Vec { - async fn respond( + async fn process_incoming( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - self.as_slice().respond(request).await + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.as_slice().process_incoming(exchange).await + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + self.as_slice().process_outgoing(response).await } } //----------- impl LocalServiceLayer ----------------------------------------- impl LocalServiceLayer for &T { - type Producer = T::Producer; - - type Transformer = T::Transformer; + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_local_incoming(self, exchange).await + } - async fn respond_local( + async fn process_local_outgoing( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - T::respond_local(self, request).await + response: OutgoingResponse<'_, '_>, + ) { + T::process_local_outgoing(self, response).await } } impl LocalServiceLayer for &mut T { - type Producer = T::Producer; - - type Transformer = T::Transformer; + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_local_incoming(self, exchange).await + } - async fn respond_local( + async fn process_local_outgoing( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - T::respond_local(self, request).await + response: OutgoingResponse<'_, '_>, + ) { + T::process_local_outgoing(self, response).await } } #[cfg(feature = "std")] impl LocalServiceLayer for Box { - type Producer = T::Producer; - - type Transformer = T::Transformer; + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_local_incoming(self, exchange).await + } - async fn respond_local( + async fn process_local_outgoing( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - T::respond_local(self, request).await + response: OutgoingResponse<'_, '_>, + ) { + T::process_local_outgoing(self, response).await } } #[cfg(feature = "std")] impl LocalServiceLayer for Rc { - type Producer = T::Producer; - - type Transformer = T::Transformer; + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_local_incoming(self, exchange).await + } - async fn respond_local( + async fn process_local_outgoing( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - T::respond_local(self, request).await + response: OutgoingResponse<'_, '_>, + ) { + T::process_local_outgoing(self, response).await } } #[cfg(feature = "std")] impl LocalServiceLayer for Arc { - type Producer = T::Producer; - - type Transformer = T::Transformer; + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_local_incoming(self, exchange).await + } - async fn respond_local( + async fn process_local_outgoing( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - T::respond_local(self, request).await + response: OutgoingResponse<'_, '_>, + ) { + T::process_local_outgoing(self, response).await } } @@ -403,23 +400,21 @@ where A: LocalServiceLayer, B: LocalServiceLayer, { - type Producer = Either; - - type Transformer = (A::Transformer, B::Transformer); - - async fn respond_local( + async fn process_local_incoming( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - let at = match self.0.respond_local(request).await { - ControlFlow::Continue(at) => at, - ControlFlow::Break(ap) => return ControlFlow::Break(Left(ap)), - }; + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.0.process_local_incoming(exchange).await?; + self.1.process_local_incoming(exchange).await?; + ControlFlow::Continue(()) + } - match self.1.respond_local(request).await { - ControlFlow::Continue(bt) => ControlFlow::Continue((at, bt)), - ControlFlow::Break(bp) => ControlFlow::Break(Right((at, bp))), - } + async fn process_local_outgoing( + &self, + mut response: OutgoingResponse<'_, '_>, + ) { + self.1.process_local_outgoing(response.reborrow()).await; + self.0.process_local_outgoing(response.reborrow()).await } } @@ -432,22 +427,25 @@ macro_rules! impl_local_service_layer_tuple { $($middle: LocalServiceLayer,)+ $last: LocalServiceLayer, { - type Producer = - <($first, ($($middle,)+ $last)) as LocalServiceLayer> - ::Producer; - - type Transformer = - <($first, ($($middle,)+ $last)) as LocalServiceLayer> - ::Transformer; + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_> + ) -> ControlFlow<()> { + #[allow(non_snake_case)] + let ($first, $($middle,)+ ref $last) = self; + $first.process_local_incoming(exchange).await?; + $($middle.process_local_incoming(exchange).await?;)+ + $last.process_local_incoming(exchange).await + } - async fn respond_local( + async fn process_local_outgoing( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow - { + response: OutgoingResponse<'_, '_> + ) { #[allow(non_snake_case)] let ($first, $($middle,)+ ref $last) = self; - ($first, ($($middle,)+ $last)).respond_local(request).await + ($first, ($($middle,)+ $last)) + .process_local_outgoing(response).await } } } @@ -466,784 +464,39 @@ impl_local_service_layer_tuple!(A..L: B C D E F G H I J K); #[cfg(feature = "std")] impl LocalServiceLayer for [T] { - type Producer = (Box<[T::Transformer]>, T::Producer); - type Transformer = Box<[T::Transformer]>; - - async fn respond_local( + async fn process_local_incoming( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - let mut transformers = Vec::new(); + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { for layer in self { - match layer.respond_local(request).await { - ControlFlow::Continue(t) => transformers.push(t), - ControlFlow::Break(p) => { - return ControlFlow::Break((transformers.into(), p)); - } - } - } - ControlFlow::Continue(transformers.into()) - } -} - -#[cfg(feature = "std")] -impl LocalServiceLayer for Vec { - type Producer = (Box<[T::Transformer]>, T::Producer); - type Transformer = Box<[T::Transformer]>; - - async fn respond_local( - &self, - request: &RequestMessage<'_>, - ) -> ControlFlow { - self.as_slice().respond_local(request).await - } -} - -//----------- impl ProduceMessage -------------------------------------------- - -impl ProduceMessage for &mut T { - fn header(&mut self, header: &mut Header) { - T::header(self, header); - } - - fn next_question<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - T::next_question(self, builder) - } - - fn next_answer<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - T::next_answer(self, builder) - } - - fn next_authority<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - T::next_authority(self, builder) - } - - fn next_additional<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - T::next_additional(self, builder) - } -} - -#[cfg(feature = "std")] -impl ProduceMessage for Box { - fn header(&mut self, header: &mut Header) { - T::header(self, header); - } - - fn next_question<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - T::next_question(self, builder) - } - - fn next_answer<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - T::next_answer(self, builder) - } - - fn next_authority<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - T::next_authority(self, builder) - } - - fn next_additional<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - T::next_additional(self, builder) - } -} - -impl ProduceMessage for Either -where - A: ProduceMessage, - B: ProduceMessage, -{ - fn header(&mut self, header: &mut Header) { - either::for_both!(self, x => x.header(header)); - } - - fn next_question<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - either::for_both!(self, x => x.next_question(builder)) - } - - fn next_answer<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - either::for_both!(self, x => x.next_answer(builder)) - } - - fn next_authority<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - either::for_both!(self, x => x.next_authority(builder)) - } - - fn next_additional<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - either::for_both!(self, x => x.next_additional(builder)) - } -} - -impl ProduceMessage for (A, B) -where - A: TransformMessage, - B: ProduceMessage, -{ - fn header(&mut self, header: &mut Header) { - B::header(&mut self.1, header); - A::modify_header(&mut self.0, header); - } - - fn next_question<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - loop { - let mut delegate = builder.reborrow(); - let mut question = match self.1.next_question(&mut delegate) { - Some(question) => question, - None => break, - }; - - if self.0.modify_question(&mut question).is_break() { - continue; - } - - return builder.resume_question(); - } - - self.0.next_question(builder) - } - - fn next_answer<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - loop { - let mut delegate = builder.reborrow(); - let mut answer = match self.1.next_answer(&mut delegate) { - Some(answer) => answer, - None => break, - }; - - if self.0.modify_answer(&mut answer).is_break() { - continue; - } - - core::mem::drop(answer); - return builder.resume_answer(); - } - - self.0.next_answer(builder) - } - - fn next_authority<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - loop { - let mut delegate = builder.reborrow(); - let mut authority = match self.1.next_authority(&mut delegate) { - Some(authority) => authority, - None => break, - }; - - if self.0.modify_authority(&mut authority).is_break() { - continue; - } - - core::mem::drop(authority); - return builder.resume_authority(); - } - - self.0.next_authority(builder) - } - - fn next_additional<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - loop { - let mut delegate = builder.reborrow(); - let mut additional = match self.1.next_additional(&mut delegate) { - Some(additional) => additional, - None => break, - }; - - if self.0.modify_additional(&mut additional).is_break() { - continue; - } - - core::mem::drop(additional); - return builder.resume_additional(); + layer.process_local_incoming(exchange).await?; } - - self.0.next_additional(builder) - } -} - -macro_rules! impl_produce_message_tuple { - ($first:ident .. $last:ident: $($middle:ident)*) => { - impl<$first, $($middle,)* $last: ?Sized> - ProduceMessage for ($first, $($middle,)* $last) - where - $first: TransformMessage, - $($middle: TransformMessage,)* - $last: ProduceMessage, - { - fn header(&mut self, header: &mut Header) { - #[allow(non_snake_case)] - let ($first, $($middle,)* ref mut $last) = self; - ($first, ($($middle,)* $last)) - .header(header) - } - - fn next_question<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - #[allow(non_snake_case)] - let ($first, $($middle,)* ref mut $last) = self; - ($first, ($($middle,)* $last)) - .next_question(builder) - } - - fn next_answer<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - #[allow(non_snake_case)] - let ($first, $($middle,)* ref mut $last) = self; - ($first, ($($middle,)* $last)) - .next_answer(builder) - } - - fn next_authority<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - #[allow(non_snake_case)] - let ($first, $($middle,)* ref mut $last) = self; - ($first, ($($middle,)* $last)) - .next_authority(builder) - } - - fn next_additional<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - #[allow(non_snake_case)] - let ($first, $($middle,)* ref mut $last) = self; - ($first, ($($middle,)* $last)) - .next_additional(builder) - } - } - }; -} - -impl_produce_message_tuple!(A..C: B); -impl_produce_message_tuple!(A..D: B C); -impl_produce_message_tuple!(A..E: B C D); -impl_produce_message_tuple!(A..F: B C D E); -impl_produce_message_tuple!(A..G: B C D E F); -impl_produce_message_tuple!(A..H: B C D E F G); -impl_produce_message_tuple!(A..I: B C D E F G H); -impl_produce_message_tuple!(A..J: B C D E F G H I); -impl_produce_message_tuple!(A..K: B C D E F G H I J); -impl_produce_message_tuple!(A..L: B C D E F G H I J K); - -impl ProduceMessage for [T] { - fn header(&mut self, header: &mut Header) { - if let [ref mut layers @ .., ref mut last] = self { - last.header(header); - for layer in layers.iter_mut().rev() { - layer.modify_header(header); - } - } - } - - fn next_question<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - let mut layers = self; - while let [ref mut nested @ .., ref mut last] = layers { - let mut delegate = builder.reborrow(); - let mut question = match last.next_question(&mut delegate) { - Some(question) => question, - None => break, - }; - - match nested - .iter_mut() - .rev() - .try_for_each(|layer| layer.modify_question(&mut question)) - { - ControlFlow::Continue(()) => { - return builder.resume_question(); - } - - ControlFlow::Break(()) => {} - } - - layers = nested; - } - - None - } - - fn next_answer<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - let mut layers = self; - while let [ref mut nested @ .., ref mut last] = layers { - let mut delegate = builder.reborrow(); - let mut answer = match last.next_answer(&mut delegate) { - Some(answer) => answer, - None => break, - }; - - match nested - .iter_mut() - .rev() - .try_for_each(|layer| layer.modify_answer(&mut answer)) - { - ControlFlow::Continue(()) => { - core::mem::drop(answer); - return builder.resume_answer(); - } - - ControlFlow::Break(()) => {} - } - - layers = nested; - } - - None - } - - fn next_authority<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - let mut layers = self; - while let [ref mut nested @ .., ref mut last] = layers { - let mut delegate = builder.reborrow(); - let mut authority = match last.next_authority(&mut delegate) { - Some(authority) => authority, - None => break, - }; - - match nested - .iter_mut() - .rev() - .try_for_each(|layer| layer.modify_authority(&mut authority)) - { - ControlFlow::Continue(()) => { - core::mem::drop(authority); - return builder.resume_authority(); - } - - ControlFlow::Break(()) => {} - } - - layers = nested; - } - - None - } - - fn next_additional<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - let mut layers = self; - while let [ref mut nested @ .., ref mut last] = layers { - let mut delegate = builder.reborrow(); - let mut additional = match last.next_additional(&mut delegate) { - Some(additional) => additional, - None => break, - }; - - match nested.iter_mut().rev().try_for_each(|layer| { - layer.modify_additional(&mut additional) - }) { - ControlFlow::Continue(()) => { - core::mem::drop(additional); - return builder.resume_additional(); - } - - ControlFlow::Break(()) => {} - } - - layers = nested; - } - - None - } -} - -#[cfg(feature = "std")] -impl ProduceMessage for Vec { - fn header(&mut self, header: &mut Header) { - self.as_mut_slice().header(header) - } - - fn next_question<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - self.as_mut_slice().next_question(builder) - } - - fn next_answer<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - self.as_mut_slice().next_answer(builder) - } - - fn next_authority<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - self.as_mut_slice().next_authority(builder) - } - - fn next_additional<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - self.as_mut_slice().next_additional(builder) - } -} - -//----------- TransformMessage ----------------------------------------------- - -impl TransformMessage for &mut T { - fn modify_header(&mut self, header: &mut Header) { - T::modify_header(self, header); - } - - fn modify_question( - &mut self, - builder: &mut QuestionBuilder<'_>, - ) -> ControlFlow<()> { - T::modify_question(self, builder) - } - - fn modify_answer( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - T::modify_answer(self, builder) - } - - fn modify_authority( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - T::modify_authority(self, builder) - } - - fn modify_additional( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - T::modify_additional(self, builder) - } -} - -#[cfg(feature = "std")] -impl TransformMessage for Box { - fn modify_header(&mut self, header: &mut Header) { - T::modify_header(self, header); - } - - fn modify_question( - &mut self, - builder: &mut QuestionBuilder<'_>, - ) -> ControlFlow<()> { - T::modify_question(self, builder) - } - - fn modify_answer( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - T::modify_answer(self, builder) - } - - fn modify_authority( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - T::modify_authority(self, builder) - } - - fn modify_additional( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - T::modify_additional(self, builder) - } -} - -impl TransformMessage for Either -where - A: TransformMessage, - B: TransformMessage, -{ - fn modify_header(&mut self, header: &mut Header) { - either::for_both!(self, x => x.modify_header(header)); - } - - fn modify_question( - &mut self, - builder: &mut QuestionBuilder<'_>, - ) -> ControlFlow<()> { - either::for_both!(self, x => x.modify_question(builder)) - } - - fn modify_answer( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - either::for_both!(self, x => x.modify_answer(builder)) - } - - fn modify_authority( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - either::for_both!(self, x => x.modify_authority(builder)) - } - - fn modify_additional( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - either::for_both!(self, x => x.modify_additional(builder)) - } -} - -impl TransformMessage for (A, B) -where - A: TransformMessage, - B: TransformMessage, -{ - fn modify_header(&mut self, header: &mut Header) { - self.1.modify_header(header); - self.0.modify_header(header); - } - - fn modify_question( - &mut self, - builder: &mut QuestionBuilder<'_>, - ) -> ControlFlow<()> { - self.1.modify_question(builder)?; - self.0.modify_question(builder)?; - ControlFlow::Continue(()) - } - - fn modify_answer( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - self.1.modify_answer(builder)?; - self.0.modify_answer(builder)?; - ControlFlow::Continue(()) - } - - fn modify_authority( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - self.1.modify_authority(builder)?; - self.0.modify_authority(builder)?; - ControlFlow::Continue(()) - } - - fn modify_additional( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - self.1.modify_additional(builder)?; - self.0.modify_additional(builder)?; ControlFlow::Continue(()) } -} - -macro_rules! impl_transform_message_tuple { - ($first:ident .. $last:ident: $($middle:ident)*) => { - impl<$first, $($middle,)* $last: ?Sized> - TransformMessage for ($first, $($middle,)* $last) - where - $first: TransformMessage, - $($middle: TransformMessage,)* - $last: TransformMessage, - { - fn modify_header(&mut self, header: &mut Header) { - #[allow(non_snake_case)] - let ($first, $($middle,)* ref mut $last) = self; - ($first, ($($middle,)* $last)) - .modify_header(header) - } - fn modify_question( - &mut self, - builder: &mut QuestionBuilder<'_>, - ) -> ControlFlow<()> { - #[allow(non_snake_case)] - let ($first, $($middle,)* ref mut $last) = self; - ($first, ($($middle,)* $last)) - .modify_question(builder) - } - - fn modify_answer( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - #[allow(non_snake_case)] - let ($first, $($middle,)* ref mut $last) = self; - ($first, ($($middle,)* $last)) - .modify_answer(builder) - } - - fn modify_authority( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - #[allow(non_snake_case)] - let ($first, $($middle,)* ref mut $last) = self; - ($first, ($($middle,)* $last)) - .modify_authority(builder) - } - - fn modify_additional( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - #[allow(non_snake_case)] - let ($first, $($middle,)* ref mut $last) = self; - ($first, ($($middle,)* $last)) - .modify_additional(builder) - } + async fn process_local_outgoing( + &self, + mut response: OutgoingResponse<'_, '_>, + ) { + for layer in self.iter().rev() { + layer.process_local_outgoing(response.reborrow()).await; } - }; -} - -impl_transform_message_tuple!(A..C: B); -impl_transform_message_tuple!(A..D: B C); -impl_transform_message_tuple!(A..E: B C D); -impl_transform_message_tuple!(A..F: B C D E); -impl_transform_message_tuple!(A..G: B C D E F); -impl_transform_message_tuple!(A..H: B C D E F G); -impl_transform_message_tuple!(A..I: B C D E F G H); -impl_transform_message_tuple!(A..J: B C D E F G H I); -impl_transform_message_tuple!(A..K: B C D E F G H I J); -impl_transform_message_tuple!(A..L: B C D E F G H I J K); - -impl TransformMessage for [T] { - fn modify_header(&mut self, header: &mut Header) { - self.iter_mut() - .rev() - .for_each(|layer| layer.modify_header(header)); - } - - fn modify_question( - &mut self, - builder: &mut QuestionBuilder<'_>, - ) -> ControlFlow<()> { - self.iter_mut() - .rev() - .try_for_each(|layer| layer.modify_question(builder)) - } - - fn modify_answer( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - self.iter_mut() - .rev() - .try_for_each(|layer| layer.modify_answer(builder)) - } - - fn modify_authority( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - self.iter_mut() - .rev() - .try_for_each(|layer| layer.modify_authority(builder)) - } - - fn modify_additional( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - self.iter_mut() - .rev() - .try_for_each(|layer| layer.modify_additional(builder)) } } #[cfg(feature = "std")] -impl TransformMessage for Vec { - fn modify_header(&mut self, header: &mut Header) { - self.as_mut_slice().modify_header(header) - } - - fn modify_question( - &mut self, - builder: &mut QuestionBuilder<'_>, - ) -> ControlFlow<()> { - self.as_mut_slice().modify_question(builder) - } - - fn modify_answer( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - self.as_mut_slice().modify_answer(builder) - } - - fn modify_authority( - &mut self, - builder: &mut RecordBuilder<'_>, +impl LocalServiceLayer for Vec { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, ) -> ControlFlow<()> { - self.as_mut_slice().modify_authority(builder) + self.as_slice().process_local_incoming(exchange).await } - fn modify_additional( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - self.as_mut_slice().modify_additional(builder) + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + self.as_slice().process_local_outgoing(response).await } } diff --git a/src/new_server/mod.rs b/src/new_server/mod.rs index 4c33c254d..ce8bd749e 100644 --- a/src/new_server/mod.rs +++ b/src/new_server/mod.rs @@ -16,15 +16,11 @@ use core::{future::Future, ops::ControlFlow}; -use crate::new_base::{ - build::{MessageBuilder, QuestionBuilder, RecordBuilder}, - Header, -}; - mod impls; -pub mod request; -pub use request::RequestMessage; +pub mod exchange; +pub use exchange::Exchange; +use exchange::OutgoingResponse; pub mod transport; @@ -43,19 +39,15 @@ pub mod transport; /// Additional functionality can be added to a service by prefixing it with /// service layers, usually in a tuple. A number of blanket implementations /// are provided to simplify this. -pub trait Service: LocalService + Sync { +pub trait Service: LocalService + Sync { /// Respond to a DNS request. /// - /// The provided consumer must have been provided the entire DNS request - /// message. This method will use the extracted information to formulate - /// a response message, in the form of a producer type. - /// - /// The returned future implements [`Send`]. Use [`LocalService`] and - /// [`LocalService::respond_local()`] if [`Send`] is not necessary. + /// The returned [`Future`] is thread-safe; it implements [`Send`]. Use + /// [`LocalService::respond_local()`] if this is not necessary. fn respond( &self, - request: &RequestMessage<'_>, - ) -> impl Future + Send; + exchange: &mut Exchange<'_>, + ) -> impl Future + Send; } //----------- LocalService --------------------------------------------------- @@ -74,30 +66,14 @@ pub trait Service: LocalService + Sync { /// service layers, usually in a tuple. A number of blanket implementations /// are provided to simplify this. pub trait LocalService { - /// A producer of DNS responses. - /// - /// This type returns components to insert in a DNS response message. It - /// is constructed by [`Self::respond_local()`]. - /// - /// # Lifetimes - /// - /// The producer can borrow from the request message (`'req`). Note that - /// it cannot borrow from the response message. - type Producer: ProduceMessage; - /// Respond to a DNS request. /// - /// The provided consumer must have been provided the entire DNS request - /// message. This method will use the extracted information to formulate - /// a response message, in the form of a producer type. - /// - /// The returned future does not implement [`Send`]. Use [`Service`] and - /// [`Service::respond()`] for a [`Send`]-implementing version. - #[allow(async_fn_in_trait)] - async fn respond_local( + /// The returned [`Future`] is thread-local; it does not implement + /// [`Send`]. Use [`Service::respond()`] for a thread-safe alternative. + fn respond_local( &self, - request: &RequestMessage<'_>, - ) -> Self::Producer; + exchange: &mut Exchange<'_>, + ) -> impl Future; } //----------- ServiceLayer --------------------------------------------------- @@ -113,25 +89,26 @@ pub trait LocalService { /// /// Layers can be combined (usually in a tuple) into larger layers. A number /// of blanket implementations are provided to simplify this. -pub trait ServiceLayer: - LocalServiceLayer + Sync -{ - /// Respond to a DNS request. +pub trait ServiceLayer: LocalServiceLayer + Sync { + /// Process an incoming DNS request. /// - /// The provided consumer must have been provided the entire DNS request - /// message. If the request should be forwarded through to the wrapped - /// service, [`ControlFlow::Continue`] is returned, with a transformer to - /// modify the wrapped service's response. However, if the request should - /// be responded to directly by this layer, without any interaction from - /// the wrapped service, [`ControlFlow::Break`] is returned. + /// The returned [`Future`] is thread-safe; it implements [`Send`]. Use + /// [`LocalServiceLayer::process_local_incoming()`] if this is not + /// necessary. + fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> impl Future> + Send; + + /// Process an outgoing DNS response. /// - /// The returned future implements [`Send`]. Use [`LocalServiceLayer`] - /// and [`LocalServiceLayer::respond_local()`] if [`Send`] is not + /// The returned [`Future`] is thread-safe; it implements [`Send`]. Use + /// [`LocalServiceLayer::process_local_outgoing()`] if this is not /// necessary. - fn respond( + fn process_outgoing( &self, - request: &RequestMessage<'_>, - ) -> impl Future> + Send; + response: OutgoingResponse<'_, '_>, + ) -> impl Future + Send; } //----------- LocalServiceLayer ---------------------------------------------- @@ -148,264 +125,23 @@ pub trait ServiceLayer: /// Layers can be combined (usually in a tuple) into larger layers. A number /// of blanket implementations are provided to simplify this. pub trait LocalServiceLayer { - /// A producer of DNS responses. - /// - /// This type returns components to insert in a DNS response message. It - /// is constructed by [`Self::respond_local()`], if a response is returned - /// early (without the wrapped service interacting with it). - /// - /// # Lifetimes - /// - /// The producer can borrow from the request message (`'req`). Note that - /// it cannot borrow from the response message. - type Producer: ProduceMessage; - - /// A transformer of DNS responses. - /// - /// This type modifies the response from the wrapped service, by adding, - /// removing, or modifying the components of the response message. It is - /// constructed by [`Self::respond_local()`], if an early return does not - /// occur. - /// - /// # Lifetimes - /// - /// The transformer can borrow from the request message (`'req`). Note - /// that it cannot borrow from the response message. - type Transformer: TransformMessage; - - /// Respond to a DNS request. - /// - /// The provided consumer must have been provided the entire DNS request - /// message. If the request should be forwarded through to the wrapped - /// service, [`ControlFlow::Continue`] is returned, with a transformer to - /// modify the wrapped service's response. However, if the request should - /// be responded to directly by this layer, without any interaction from - /// the wrapped service, [`ControlFlow::Break`] is returned. + /// Process an incoming DNS request. /// - /// The returned future does not implement [`Send`]. Use [`ServiceLayer`] - /// and [`ServiceLayer::respond()`] for a [`Send`]-implementing version. - #[allow(async_fn_in_trait)] - async fn respond_local( + /// The returned [`Future`] is thread-local; it does not implement + /// [`Send`]. Use [`ServiceLayer::process_incoming()`] for a thread-safe + /// alternative. + fn process_local_incoming( &self, - request: &RequestMessage<'_>, - ) -> ControlFlow; -} - -//----------- ProduceMessage ------------------------------------------------- + exchange: &mut Exchange<'_>, + ) -> impl Future>; -/// A type that produces a DNS message. -/// -/// This interface is similar to [`Iterator`], except that it can iterate over -/// the different components of a message (questions, answers, authorities, -/// and additional records). -/// -/// # Architecture -/// -/// This interface is convenient when multiple transformers need to modify the -/// message as it is being built. Rather than forcing each transformer to -/// parse and rewrite the message, this interface allows the message to built -/// up over a single iteration, with every transformer directly examining each -/// component added to the message. -/// -/// # Examples -pub trait ProduceMessage { - /// The header of the message. - /// - /// The provided header will be uninitialized, and this method is expected - /// to reset it entirely. The default implementation does nothing. - fn header(&mut self, header: &mut Header) { - let _ = header; - } - - /// The next DNS question in the message. - /// - /// This method is expected to add at most one question using the given - /// message builder. If a question is added, its builder is returned so - /// that it can be modified or filtered before being finalized. - /// - /// This must act like a fused iterator; if no question is added, then - /// future calls to the same method will also add no questions. - /// - /// The default implementation of this method will add no questions. - /// - /// # Errors - /// - /// If new records cannot be inserted in the message (because it is full), - /// the method is responsible for returning gracefully. The message may - /// be marked as truncated, for example. - fn next_question<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - let _ = builder; - None - } - - /// The next answer record in the message. - /// - /// This method is expected to add at most one answer record using the - /// given message builder. If a record is added, its builder is returned - /// so that it can be modified or filtered before being finalized. - /// - /// This must act like a fused iterator; if no record is added, then - /// future calls to the same method will also add no records. + /// Process an outgoing DNS response. /// - /// The default implementation of this method will add no records. - /// - /// # Errors - /// - /// If new records cannot be inserted in the message (because it is full), - /// the method is responsible for returning gracefully. The message may - /// be marked as truncated, for example. - fn next_answer<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - let _ = builder; - None - } - - /// The next authority record in the message. - /// - /// This method is expected to add at most one authority record using the - /// given message builder. If a record is added, its builder is returned - /// so that it can be modified or filtered before being finalized. - /// - /// This must act like a fused iterator; if no record is added, then - /// future calls to the same method will also add no records. - /// - /// The default implementation of this method will add no records. - /// - /// # Errors - /// - /// If new records cannot be inserted in the message (because it is full), - /// the method is responsible for returning gracefully. The message may - /// be marked as truncated, for example. - fn next_authority<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - let _ = builder; - None - } - - /// The next additional record in the message. - /// - /// This method is expected to add at most one additional record using the - /// given message builder. If a record is added, its builder is returned - /// so that it can be modified or filtered before being finalized. - /// - /// This must act like a fused iterator; if no record is added, then - /// future calls to the same method will also add no records. - /// - /// The default implementation of this method will add no records. - /// - /// # Errors - /// - /// If new records cannot be inserted in the message (because it is full), - /// the method is responsible for returning gracefully. The message may - /// be marked as truncated, for example. - fn next_additional<'b>( - &mut self, - builder: &'b mut MessageBuilder<'_>, - ) -> Option> { - let _ = builder; - None - } -} - -//----------- TransformMessage ----------------------------------------------- - -/// A type that modifies a DNS message as it is being built. -/// -/// This interface is designed around [`ProduceMessage`]: as the components of -/// the message are produced, they are passed through methods of this trait to -/// be modified or filtered out. Furthermore, implementing types can add more -/// components to the message as they also implement [`ProduceMessage`]. -/// -/// # Examples -pub trait TransformMessage: ProduceMessage { - /// Modify the header of the message. - /// - /// The provided header has been initialized; this method can choose to - /// modify it. The default implementation does nothing. - fn modify_header(&mut self, header: &mut Header) { - let _ = header; - } - - /// Modify a question added to the message. - /// - /// This method is called when a question is being added to the message. - /// The question can be modified. - /// - /// If [`ControlFlow::Continue`] is returned, the question is preserved, - /// and can be passed through future transformations. Otherwise, the - /// question is removed. - /// - /// The default implementation of this method passes the question through - /// transparently. - fn modify_question( - &mut self, - builder: &mut QuestionBuilder<'_>, - ) -> ControlFlow<()> { - let _ = builder; - ControlFlow::Continue(()) - } - - /// Modify an answer record added to the message. - /// - /// This method is called when an answer record is being added to the - /// message. The record (and its data) can be modified. - /// - /// If [`ControlFlow::Continue`] is returned, the record is preserved, - /// and can be passed through future transformations. Otherwise, the - /// record is removed. - /// - /// The default implementation of this method passes the record through - /// transparently. - fn modify_answer( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - let _ = builder; - ControlFlow::Continue(()) - } - - /// Modify an authority record added to the message. - /// - /// This method is called when an authority record is being added to the - /// message. The record (and its data) can be modified. - /// - /// If [`ControlFlow::Continue`] is returned, the record is preserved, - /// and can be passed through future transformations. Otherwise, the - /// record is removed. - /// - /// The default implementation of this method passes the record through - /// transparently. - fn modify_authority( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - let _ = builder; - ControlFlow::Continue(()) - } - - /// Modify an additional record added to the message. - /// - /// This method is called when an additional record is being added to the - /// message. The record (and its data) can be modified. - /// - /// If [`ControlFlow::Continue`] is returned, the record is preserved, - /// and can be passed through future transformations. Otherwise, the - /// record is removed. - /// - /// The default implementation of this method passes the record through - /// transparently. - fn modify_additional( - &mut self, - builder: &mut RecordBuilder<'_>, - ) -> ControlFlow<()> { - let _ = builder; - ControlFlow::Continue(()) - } + /// The returned [`Future`] is thread-local; it does not implement + /// [`Send`]. Use [`ServiceLayer::process_outgoing()`] for a thread-safe + /// alternative. + fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) -> impl Future; } diff --git a/src/new_server/request.rs b/src/new_server/request.rs deleted file mode 100644 index 8a0eadc48..000000000 --- a/src/new_server/request.rs +++ /dev/null @@ -1,591 +0,0 @@ -//! DNS request messages. - -use core::{iter::FusedIterator, marker::PhantomData, ops::Range}; - -use crate::{ - new_base::{ - name::{Name, UnparsedName}, - parse::{ParseMessageBytes, SplitMessageBytes}, - wire::{AsBytes, ParseBytes, ParseError, SizePrefixed, U16}, - Message, ParseRecordData, QClass, QType, Question, RClass, RType, - Record, SectionCounts, UnparsedRecordData, TTL, - }, - new_edns::{EdnsOption, EdnsRecord}, - new_rdata::EdnsOptionsIter, -}; - -/// A DNS request message. -#[derive(Clone)] -pub struct RequestMessage<'b> { - /// The underlying [`Message`]. - pub message: &'b Message, - - /// Cached offsets for the question section. - questions: (Range, [Range; 1]), - - /// Cached offsets for the answer section. - answers: (Range, [Range; 0]), - - /// Cached offsets for the authority section. - authorities: (Range, [Range; 0]), - - /// Cached offsets for the additional section. - additional: (Range, [Range; 2]), - - /// Cached offsets for the EDNS record. - edns: (Range, u16, [Range; 4]), -} - -//--- Construction - -impl<'b> RequestMessage<'b> { - /// Wrap a raw [`Message`] into a [`RequestMessage`]. - /// - /// This will iterate through the message, pre-filling some caches for - /// efficient access in the future. - pub fn new(message: &'b Message) -> Result { - /// Parse the question section into cached offsets. - fn parse_questions( - contents: &[u8], - range: &mut Range, - number: u16, - indices: &mut [Range], - ) -> Result<(), ParseError> { - let mut indices = indices.iter_mut(); - let mut offset = range.start as usize; - - for _ in 0..number { - let (question, rest) = - Question::<&UnparsedName>::split_message_bytes( - contents, offset, - )?; - - if let Some(indices) = indices.next() { - let fields = offset + question.qname.len(); - *indices = offset as u16..fields as u16; - } - - offset = rest; - } - - range.end = offset as u16; - Ok(()) - } - - /// Parse a record section into cached offsets. - fn parse_records( - contents: &[u8], - section: u8, - range: &mut Range, - number: u16, - indices: &mut [Range], - edns_range: &mut Option>, - ) -> Result<(), ParseError> { - let mut indices = indices.iter_mut(); - let mut offset = range.start as usize; - - for _ in 0..number { - let (record, rest) = Record::< - &UnparsedName, - &UnparsedRecordData, - >::split_message_bytes( - contents, offset - )?; - - if let Some(indices) = indices.next() { - let fields = offset + record.rname.len(); - *indices = offset as u16..fields as u16; - } - - if section == 3 && record.rtype == RType::OPT { - if edns_range.is_some() { - // A DNS message can only contain one EDNS record. - return Err(ParseError); - } - - *edns_range = Some(offset as u16..rest as u16); - } - - offset = rest; - } - - range.end = offset as u16; - Ok(()) - } - - /// Parse the EDNS record into cached offsets. - fn parse_edns( - contents: &[u8], - range: Range, - number: &mut u16, - indices: &mut [Range], - ) -> Result<(), ParseError> { - let mut indices = indices.iter_mut(); - let mut offset = range.start as usize + 11; - - while offset < range.end as usize { - let (_type, rest) = - <&U16>::split_message_bytes(contents, offset)?; - let (_data, rest) = - >::split_message_bytes( - contents, rest, - )?; - - *number += 1; - - if let Some(indices) = indices.next() { - *indices = offset as u16..rest as u16; - } - - offset = rest; - } - - Ok(()) - } - - // DNS messages are 64KiB at the largest. - let _ = u16::try_from(message.as_bytes().len()) - .map_err(|_| ParseError)?; - - let mut this = Self { - message, - questions: Default::default(), - answers: Default::default(), - authorities: Default::default(), - additional: Default::default(), - edns: Default::default(), - }; - - let mut edns_range = None; - - parse_questions( - &message.contents, - &mut this.questions.0, - message.header.counts.questions.get(), - &mut this.questions.1, - )?; - - this.answers.0 = this.questions.0.end..0; - parse_records( - &message.contents, - 1, - &mut this.answers.0, - message.header.counts.answers.get(), - &mut this.answers.1, - &mut edns_range, - )?; - - this.authorities.0 = this.answers.0.end..0; - parse_records( - &message.contents, - 2, - &mut this.authorities.0, - message.header.counts.authorities.get(), - &mut this.authorities.1, - &mut edns_range, - )?; - - this.additional.0 = this.authorities.0.end..0; - parse_records( - &message.contents, - 2, - &mut this.additional.0, - message.header.counts.additional.get(), - &mut this.additional.1, - &mut edns_range, - )?; - - if let Some(edns_range) = edns_range { - this.edns.0 = edns_range.clone(); - parse_edns( - &message.contents, - edns_range, - &mut this.edns.1, - &mut this.edns.2, - )?; - } - - Ok(this) - } -} - -//--- Internals - -impl<'b> RequestMessage<'b> { - /// The section counts. - fn counts(&self) -> &'b SectionCounts { - &self.message.header.counts - } -} - -//--- Inspection - -impl<'b> RequestMessage<'b> { - /// The sole question in the message. - /// - /// # Name Compression - /// - /// Due to the restrictions around compressed domain names (in order to - /// prevent attackers from crafting compression pointer loops), it is - /// guaranteed that the first QNAME in the message is uncompressed. - /// - /// # Errors - /// - /// Fails if there are zero or more than one question in the message. - pub fn sole_question(&self) -> Result, ParseError> { - if self.message.header.counts.questions.get() != 1 { - return Err(ParseError); - } - - // SAFETY: 'RequestMessage' is pre-validated. - let range = self.questions.1[0].clone(); - let range = range.start as usize..range.end as usize; - let qname = &self.message.contents[range.clone()]; - let qname = unsafe { Name::from_bytes_unchecked(qname) }; - let fields = &self.message.contents[range.end..]; - let qtype = QType::parse_bytes(&fields[0..2]).unwrap(); - let qclass = QClass::parse_bytes(&fields[2..4]).unwrap(); - - Ok(Question { - qname, - qtype, - qclass, - }) - } - - /// The EDNS record in the message, if any. - pub fn edns_record(&self) -> Option> { - if self.edns.0.is_empty() { - return None; - } - - let range = self.edns.0.clone(); - let contents = &self.message.contents[..range.end as usize]; - EdnsRecord::parse_message_bytes(contents, range.start as usize) - .map(Some) - .expect("'RequestMessage' only holds well-formed EDNS records") - } - - /// The questions in the message. - /// - /// An iterator of [`Question`]s is returned. - /// - /// # Name Compression - /// - /// The returned questions use [`UnparsedName`] for the QNAMEs. These can - /// be resolved against the original message to determine the whole domain - /// name, if necessary. Note that decompression can fail. - pub fn questions(&self) -> RequestQuestions<'_, 'b> { - let contents = self.questions.0.clone(); - RequestQuestions { - message: self, - cache: self.questions.1.iter(), - contents: contents.start as usize..contents.end as usize, - indices: 0..self.counts().questions.get(), - } - } - - /// The answer records in the message. - /// - /// An iterator of [`Record`]s is returned. - /// - /// # Name Compression - /// - /// The returned records use [`UnparsedName`] for the RNAMEs. These can - /// be resolved against the original message to determine the whole domain - /// name, if necessary. Note that decompression can fail. - /// - /// # Record Data - /// - /// The caller can select an appropriate record data type to use. In most - /// cases, [`RecordData`](crate::new_rdata::RecordData) is appropriate; if - /// many records will be skipped, however, [`UnparsedRecordData`] might be - /// preferable. - pub fn answers(&self) -> RequestRecords<'_, 'b, D> - where - D: ParseRecordData<'b>, - { - let contents = self.answers.0.clone(); - RequestRecords { - message: self, - cache: self.answers.1.iter(), - contents: contents.start as usize..contents.end as usize, - indices: 0..self.counts().answers.get(), - _rdata: PhantomData, - } - } - - /// The authority records in the message. - /// - /// An iterator of [`Record`]s is returned. - /// - /// # Name Compression - /// - /// The returned records use [`UnparsedName`] for the RNAMEs. These can - /// be resolved against the original message to determine the whole domain - /// name, if necessary. Note that decompression can fail. - /// - /// # Record Data - /// - /// The caller can select an appropriate record data type to use. In most - /// cases, [`RecordData`](crate::new_rdata::RecordData) is appropriate; if - /// many records will be skipped, however, [`UnparsedRecordData`] might be - /// preferable. - pub fn authorities(&self) -> RequestRecords<'_, 'b, D> - where - D: ParseRecordData<'b>, - { - let contents = self.authorities.0.clone(); - RequestRecords { - message: self, - cache: self.authorities.1.iter(), - contents: contents.start as usize..contents.end as usize, - indices: 0..self.counts().authorities.get(), - _rdata: PhantomData, - } - } - - /// The additional records in the message. - /// - /// An iterator of [`Record`]s is returned. - /// - /// # Name Compression - /// - /// The returned records use [`UnparsedName`] for the RNAMEs. These can - /// be resolved against the original message to determine the whole domain - /// name, if necessary. Note that decompression can fail. - /// - /// # Record Data - /// - /// The caller can select an appropriate record data type to use. In most - /// cases, [`RecordData`](crate::new_rdata::RecordData) is appropriate; if - /// many records will be skipped, however, [`UnparsedRecordData`] might be - /// preferable. - pub fn additional(&self) -> RequestRecords<'_, 'b, D> - where - D: ParseRecordData<'b>, - { - let contents = self.additional.0.clone(); - RequestRecords { - message: self, - cache: self.additional.1.iter(), - contents: contents.start as usize..contents.end as usize, - indices: 0..self.counts().additional.get(), - _rdata: PhantomData, - } - } - - /// The EDNS options in the message. - /// - /// An iterator of (results of) [`EdnsOption`]s is returned. - /// - /// If the message does not contain an EDNS record, this is empty. - /// - /// # Errors - /// - /// While the overall structure of the EDNS record is verified when the - /// [`RequestMessage`] is constructed, option-specific validation (e.g. - /// checking the size of an EDNS cookie option) is not performed early. - /// Those errors will be caught and returned by the iterator. - pub fn edns_options(&self) -> RequestEdnsOptions<'b> { - let start = self.edns.0.start as usize + 11; - let end = self.edns.0.end as usize; - let options = &self.message.contents[start..end]; - RequestEdnsOptions { - inner: EdnsOptionsIter::new(options), - indices: 0..self.edns.1, - } - } -} - -//----------- RequestQuestions ----------------------------------------------- - -/// The questions in a [`RequestMessage`]. -#[derive(Clone)] -pub struct RequestQuestions<'r, 'b> { - /// The underlying request message. - message: &'r RequestMessage<'b>, - - /// The cached question ranges. - cache: core::slice::Iter<'r, Range>, - - /// The range of message contents to parse. - contents: Range, - - /// The range of record indices left. - indices: Range, -} - -impl<'b> Iterator for RequestQuestions<'_, 'b> { - type Item = Question<&'b UnparsedName>; - - fn next(&mut self) -> Option { - // Try loading a cached question. - if let Some(range) = self.cache.next().cloned() { - if range.is_empty() { - // There are no more questions, stop. - self.cache = Default::default(); - self.contents.start = self.contents.end; - return None; - } - - // SAFETY: 'RequestMessage' is pre-validated. - let range = range.start as usize..range.end as usize; - let qname = &self.message.message.contents[range.clone()]; - let qname = unsafe { UnparsedName::from_bytes_unchecked(qname) }; - let fields = &self.message.message.contents[range.end..]; - let qtype = QType::parse_bytes(&fields[0..2]).unwrap(); - let qclass = QClass::parse_bytes(&fields[2..4]).unwrap(); - - self.indices.start += 1; - return Some(Question { - qname, - qtype, - qclass, - }); - } - - let _ = self.indices.next()?; - let contents = &self.message.message.contents[..self.contents.end]; - let (question, rest) = - Question::split_message_bytes(contents, self.contents.start) - .expect("'RequestMessage' only contains valid questions"); - - self.contents.start = rest; - Some(question) - } -} - -impl ExactSizeIterator for RequestQuestions<'_, '_> { - fn len(&self) -> usize { - self.indices.len() - } -} - -impl FusedIterator for RequestQuestions<'_, '_> {} - -//----------- RequestRecords ------------------------------------------------- - -/// The records in a section of a [`RequestMessage`]. -#[derive(Clone)] -pub struct RequestRecords<'r, 'b, D> { - /// The underlying request message. - message: &'r RequestMessage<'b>, - - /// The cached record ranges. - cache: core::slice::Iter<'r, Range>, - - /// The range of message contents to parse. - contents: Range, - - /// The range of record indices left. - indices: Range, - - /// A representation of the record data held. - _rdata: PhantomData<&'r [D]>, -} - -impl<'b, D> Iterator for RequestRecords<'_, 'b, D> -where - D: ParseRecordData<'b>, -{ - type Item = Result, ParseError>; - - fn next(&mut self) -> Option { - // Try loading a cached record. - if let Some(range) = self.cache.next().cloned() { - if range.is_empty() { - // There are no more records, stop. - self.cache = Default::default(); - self.contents.start = self.contents.end; - return None; - } - - // SAFETY: 'RequestMessage' is pre-validated. - let range = range.start as usize..range.end as usize; - let rname = &self.message.message.contents[range.clone()]; - let rname = unsafe { UnparsedName::from_bytes_unchecked(rname) }; - let fields = &self.message.message.contents[range.end..]; - let rtype = RType::parse_bytes(&fields[0..2]).unwrap(); - let rclass = RClass::parse_bytes(&fields[2..4]).unwrap(); - let ttl = TTL::parse_bytes(&fields[4..8]).unwrap(); - let size = U16::parse_bytes(&fields[8..10]).unwrap(); - let rdata_end = range.end + 10 + size.get() as usize; - let rdata = &self.message.message.contents[..rdata_end]; - let rdata = - match D::parse_record_data(rdata, range.end + 10, rtype) { - Ok(rdata) => rdata, - Err(err) => return Some(Err(err)), - }; - - self.indices.start += 1; - return Some(Ok(Record { - rname, - rtype, - rclass, - ttl, - rdata, - })); - } - - let _ = self.indices.next()?; - let contents = &self.message.message.contents[..self.contents.end]; - let (record, rest) = match Record::split_message_bytes( - contents, - self.contents.start, - ) { - Ok((record, rest)) => (record, rest), - Err(err) => return Some(Err(err)), - }; - - self.contents.start = rest; - Some(Ok(record)) - } -} - -impl<'b, D> ExactSizeIterator for RequestRecords<'_, 'b, D> -where - D: ParseRecordData<'b>, -{ - fn len(&self) -> usize { - self.indices.len() - } -} - -impl<'b, D> FusedIterator for RequestRecords<'_, 'b, D> where - D: ParseRecordData<'b> -{ -} - -//----------- RequestEdnsOptions --------------------------------------------- - -/// The EDNS options in a [`RequestMessage`]. -/// -/// This is a wrapper around [`EdnsOptionsIter`] that also implements -/// [`ExactSizeIterator`], as the total number of EDNS options is cached in -/// the [`RequestMessage`]. -#[derive(Clone)] -pub struct RequestEdnsOptions<'b> { - /// The underlying iterator. - inner: EdnsOptionsIter<'b>, - - /// The range of option indices left. - indices: Range, -} - -impl<'b> Iterator for RequestEdnsOptions<'b> { - type Item = Result, ParseError>; - - fn next(&mut self) -> Option { - let _ = self.indices.next()?; - self.inner.next() - } -} - -impl ExactSizeIterator for RequestEdnsOptions<'_> { - fn len(&self) -> usize { - self.indices.len() - } -} - -impl FusedIterator for RequestEdnsOptions<'_> {} diff --git a/src/new_server/transport/mod.rs b/src/new_server/transport/mod.rs index e7aa9350b..4bb193633 100644 --- a/src/new_server/transport/mod.rs +++ b/src/new_server/transport/mod.rs @@ -1,17 +1,20 @@ //! Network transports for DNS servers. use core::net::SocketAddr; -use std::{io, sync::Arc, vec::Vec}; +use std::{io, sync::Arc, time::SystemTime, vec::Vec}; +use bumpalo::Bump; use tokio::net::UdpSocket; -use crate::new_base::{ - build::{BuilderContext, MessageBuilder}, - wire::{AsBytes, ParseBytesByRef}, - Message, +use crate::{ + new_base::{ + wire::{AsBytes, ParseBytesByRef}, + Message, + }, + new_server::exchange::Allocator, }; -use super::{ProduceMessage, RequestMessage, Service}; +use super::{exchange::ParsedMessage, Exchange, Service}; //----------- serve_udp() ---------------------------------------------------- @@ -36,17 +39,17 @@ pub async fn serve_udp( impl State { /// Respond to a particular UDP request. - async fn respond( - self: Arc, - mut buffer: Vec, - peer: SocketAddr, - ) { + async fn respond(self: Arc, buffer: Vec, peer: SocketAddr) { let Ok(message) = Message::parse_bytes_by_ref(&buffer) else { // This message is fundamentally invalid, just give up. return; }; - let Ok(request) = RequestMessage::new(message) else { + let mut allocator = Bump::new(); + let mut allocator = Allocator::new(&mut allocator); + + let Ok(request) = ParsedMessage::parse(message, &mut allocator) + else { // This message is malformed; inform the peer and stop. let mut buffer = [0u8; 12]; let response = Message::parse_bytes_by_mut(&mut buffer) @@ -58,38 +61,27 @@ pub async fn serve_udp( return; }; + // Build a complete 'Exchange' around the request. + let mut exchange = Exchange { + alloc: allocator, + reception: SystemTime::now(), + request, + response: ParsedMessage::default(), + metadata: Vec::new(), + }; + // Generate the appropriate response. - let mut producer = self.service.respond(&request).await; + self.service.respond(&mut exchange).await; // Build up the response message. - - buffer.clear(); - buffer.resize(65536, 0); - let mut context = BuilderContext::default(); - let mut builder = MessageBuilder::new(&mut buffer, &mut context); - - producer.header(builder.header_mut()); - while let Some(question) = producer.next_question(&mut builder) { - question.commit(); - } - while let Some(answer) = producer.next_answer(&mut builder) { - answer.commit(); - } - while let Some(authority) = producer.next_authority(&mut builder) - { - authority.commit(); - } - while let Some(additional) = - producer.next_additional(&mut builder) - { - additional.commit(); - } + let mut buffer = vec![0u8; 65536]; + let message = + exchange.response.build(&mut buffer).unwrap_or_else(|_| { + todo!("how to handle truncation errors?") + }); // Send the response back to the peer. - let _ = self - .socket - .send_to(builder.message().as_bytes(), peer) - .await; + let _ = self.socket.send_to(message.as_bytes(), peer).await; } } From feb68af9e1514cddcf96b0f7fc46724d5c447cfd Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 3 Feb 2025 11:00:25 +0100 Subject: [PATCH 126/167] [new_server/exchange] Make 'new()' non-const --- src/new_server/exchange.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/new_server/exchange.rs b/src/new_server/exchange.rs index 3b303a360..481b099cc 100644 --- a/src/new_server/exchange.rs +++ b/src/new_server/exchange.rs @@ -424,7 +424,7 @@ pub struct Allocator<'a> { impl<'a> Allocator<'a> { /// Construct a new [`Allocator`]. - pub const fn new(inner: &'a mut Bump) -> Self { + pub fn new(inner: &'a mut Bump) -> Self { // NOTE: The 'Bump' is mutably borrowed for lifetime 'a; the reference // we store is thus guaranteed to be unique. Self { inner } From 4f9717608f529bfe4675563af3f53dc6a4602861 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 3 Feb 2025 11:32:14 +0100 Subject: [PATCH 127/167] [new_server/exchange] Enhance documentation --- src/new_server/exchange.rs | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/new_server/exchange.rs b/src/new_server/exchange.rs index 481b099cc..11e33ec58 100644 --- a/src/new_server/exchange.rs +++ b/src/new_server/exchange.rs @@ -1,4 +1,9 @@ //! Request-response exchanges for DNS servers. +//! +//! This module provides a number of utility types for the DNS service layer +//! architecture. In particular, an [`Exchange`] represents a DNS request as +//! it is being passed along a server pipeline, and an [`OutgoingResponse`] is +//! the corresponding response as it is passed back through. use core::any::{Any, TypeId}; use std::{boxed::Box, time::SystemTime, vec::Vec}; @@ -244,7 +249,8 @@ impl<'a> ParsedMessage<'a> { /// Build this message into the given buffer. /// - /// If the message was too large, a [`TruncationError`] is returned. + /// If the message could not fit in the given buffer, a + /// [`TruncationError`] is returned. pub fn build<'b>( &self, buffer: &'b mut [u8], @@ -341,8 +347,18 @@ impl ParsedMessage<'_> { /// Arbitrary metadata about a DNS exchange. /// +/// This should be used by [`ServiceLayer`](super::ServiceLayer)s for storing +/// information they have extracted from an incoming DNS request message. The +/// metadata may be relevant to future layers: for example, some may wish to +/// handle TSIG-signed requests differently from others. The metadata is also +/// relevant to the original layer in [`process_outgoing()`], as it does not +/// have access to the original request. +/// +/// # Implementation +/// /// This is an enhanced version of `Box` that can -/// perform downcasting more efficiently. +/// perform downcasting more efficiently. It stores the [`TypeId`] of the +/// object inline, allowing it to skip a vtable lookup. pub struct Metadata { /// The type ID of the object. type_id: TypeId, @@ -408,6 +424,18 @@ impl Metadata { /// A bump allocator with a fixed lifetime. /// /// This is a wrapper around [`bumpalo::Bump`] that guarantees thread safety. +/// It is equivalent to `&'a mut Bump`, but `&mut &'a mut Bump` does not work +/// (allocated objects only last for the shorter lifetime, not for `'a`). +/// `&mut Allocator<'a>` does work, giving objects of lifetime `'a`. +/// +/// # Thread Safety +/// +/// [`Bump`] is not thread safe; using it from multiple threads simultaneously +/// would cause undefined behaviour. [`Allocator`] implements [`Send`], and +/// so it cannot directly expose shared references to the underlying [`Bump`]; +/// a user could get `&Bump` on one thread, send the [`Allocator`] to another +/// thread, then get `&Bump` over there. This is why [`Allocator`] copies +/// [`Bump`]'s methods instead of implementing [`Deref`] to [`Bump`]. #[derive(Debug)] #[repr(transparent)] pub struct Allocator<'a> { From cdf18b7663d1a6a659b875cd58e6818948953864 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 4 Feb 2025 09:44:50 +0100 Subject: [PATCH 128/167] [new_server/exchange] Fix broken doc links --- src/new_server/exchange.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/new_server/exchange.rs b/src/new_server/exchange.rs index 11e33ec58..20b1621f5 100644 --- a/src/new_server/exchange.rs +++ b/src/new_server/exchange.rs @@ -354,6 +354,8 @@ impl ParsedMessage<'_> { /// relevant to the original layer in [`process_outgoing()`], as it does not /// have access to the original request. /// +/// [`process_outgoing()`]: super::ServiceLayer::process_outgoing() +/// /// # Implementation /// /// This is an enhanced version of `Box` that can @@ -436,6 +438,8 @@ impl Metadata { /// a user could get `&Bump` on one thread, send the [`Allocator`] to another /// thread, then get `&Bump` over there. This is why [`Allocator`] copies /// [`Bump`]'s methods instead of implementing [`Deref`] to [`Bump`]. +/// +/// [`Deref`]: core::ops::Deref #[derive(Debug)] #[repr(transparent)] pub struct Allocator<'a> { From 0cd88abf08f64a54c848648e6105a417d8df54b1 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 4 Feb 2025 18:37:12 +0100 Subject: [PATCH 129/167] [new_base] Fix various minor bugs --- src/new_base/build/message.rs | 2 +- src/new_base/name/reversed.rs | 36 +++++++++++++------------ src/new_base/parse/mod.rs | 2 +- src/new_base/question.rs | 49 +++++++++++++++++++++++++++++++++-- src/new_base/record.rs | 48 ++++++++++++++++++++++++++++++---- src/new_base/wire/ints.rs | 6 +++++ src/new_edns/mod.rs | 1 + 7 files changed, 118 insertions(+), 26 deletions(-) diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 6f5311383..0dc2bdd95 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -89,7 +89,7 @@ impl MessageBuilder<'_, '_> { impl<'b> MessageBuilder<'b, '_> { /// End the builder, returning the built message. pub fn finish(self) -> &'b Message { - self.message + self.message.slice_to(self.context.size) } /// Reborrow the builder with a shorter lifetime. diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 551082a31..2edf9a75b 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -195,24 +195,20 @@ impl Hash for RevName { impl fmt::Debug for RevName { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - struct RevLabels<'a>(&'a RevName); - - impl fmt::Debug for RevLabels<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut first = true; - self.0.labels().try_for_each(|label| { - if !first { - f.write_str(".")?; - } else { - first = false; - } - - label.fmt(f) - }) + f.write_str("RevName(")?; + + let mut first = true; + self.labels().try_for_each(|label| { + if !first { + f.write_str(".")?; + } else { + first = false; } - } - f.debug_tuple("RevName").field(&RevLabels(self)).finish() + fmt::Display::fmt(&label, f) + })?; + + f.write_str(")") } } @@ -482,7 +478,7 @@ impl AsMut for RevNameBuf { } } -//--- Forwarding equality, comparison, and hashing +//--- Forwarding equality, comparison, hashing, and formatting impl PartialEq for RevNameBuf { fn eq(&self, that: &Self) -> bool { @@ -509,3 +505,9 @@ impl Hash for RevNameBuf { (**self).hash(state) } } + +impl fmt::Debug for RevNameBuf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 03ac16995..5a16fd17f 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -54,7 +54,7 @@ impl<'a, T: ?Sized + SplitBytesByRef> SplitMessageBytes<'a> for &'a T { start: usize, ) -> Result<(Self, usize), ParseError> { T::split_bytes_by_ref(&contents[start..]) - .map(|(this, rest)| (this, contents.len() - start - rest.len())) + .map(|(this, rest)| (this, contents.len() - rest.len())) } } diff --git a/src/new_base/question.rs b/src/new_base/question.rs index dbcc2e3c9..c7a5203ce 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -1,5 +1,7 @@ //! DNS questions. +use core::fmt; + use domain_macros::*; use super::{ @@ -40,6 +42,19 @@ impl Question { } } +//--- Interaction + +impl Question { + /// Map the name in this question to another type. + pub fn map_name R>(self, f: F) -> Question { + Question { + qname: (f)(self.qname), + qtype: self.qtype, + qclass: self.qclass, + } + } +} + //--- Parsing from DNS messages impl<'a, N> SplitMessageBytes<'a> for Question @@ -96,7 +111,6 @@ where #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -155,13 +169,32 @@ impl QType { pub const AAAA: Self = Self::new(28); } +//--- Formatting + +impl fmt::Debug for QType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::A => "QType::A", + Self::NS => "QType::NS", + Self::CNAME => "QType::CNAME", + Self::SOA => "QType::SOA", + Self::WKS => "QType::WKS", + Self::PTR => "QType::PTR", + Self::HINFO => "QType::HINFO", + Self::MX => "QType::MX", + Self::TXT => "QType::TXT", + Self::AAAA => "QType::AAAA", + _ => return write!(f, "QType({})", self.code), + }) + } +} + //----------- QClass --------------------------------------------------------- /// The class of a question. #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -195,3 +228,15 @@ impl QClass { /// The CHAOS class. pub const CH: Self = Self::new(3); } + +//--- Formatting + +impl fmt::Debug for QClass { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::IN => "QClass::IN", + Self::CH => "QClass::CH", + _ => return write!(f, "QClass({})", self.code), + }) + } +} diff --git a/src/new_base/record.rs b/src/new_base/record.rs index eef4c2840..85eefd732 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -1,6 +1,6 @@ //! DNS records. -use core::{borrow::Borrow, ops::Deref}; +use core::{borrow::Borrow, fmt, ops::Deref}; use super::{ build::{self, BuildIntoMessage, BuildResult}, @@ -76,7 +76,7 @@ where let (_, rest) = <&SizePrefixed<[u8]>>::split_message_bytes(contents, rest)?; let rdata = - D::parse_record_data(&contents[..rest], rdata_start, rtype)?; + D::parse_record_data(&contents[..rest], rdata_start + 2, rtype)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } @@ -178,7 +178,6 @@ where #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -240,13 +239,33 @@ impl RType { pub const OPT: Self = Self::new(41); } +//--- Formatting + +impl fmt::Debug for RType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::A => "RType::A", + Self::NS => "RType::NS", + Self::CNAME => "RType::CNAME", + Self::SOA => "RType::SOA", + Self::WKS => "RType::WKS", + Self::PTR => "RType::PTR", + Self::HINFO => "RType::HINFO", + Self::MX => "RType::MX", + Self::TXT => "RType::TXT", + Self::AAAA => "RType::AAAA", + Self::OPT => "RType::OPT", + _ => return write!(f, "RType({})", self.code), + }) + } +} + //----------- RClass --------------------------------------------------------- /// The class of a record. #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -281,13 +300,24 @@ impl RClass { pub const CH: Self = Self::new(3); } +//--- Formatting + +impl fmt::Debug for RClass { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::IN => "RClass::IN", + Self::CH => "RClass::CH", + _ => return write!(f, "RClass({})", self.code), + }) + } +} + //----------- TTL ------------------------------------------------------------ /// How long a record can be cached. #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -322,6 +352,14 @@ impl From for u32 { } } +//--- Formatting + +impl fmt::Debug for TTL { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "TTL({})", self.value) + } +} + //----------- ParseRecordData ------------------------------------------------ /// Parsing DNS record data. diff --git a/src/new_base/wire/ints.rs b/src/new_base/wire/ints.rs index acd2990e2..a8170ea83 100644 --- a/src/new_base/wire/ints.rs +++ b/src/new_base/wire/ints.rs @@ -122,6 +122,12 @@ macro_rules! define_int { } } + impl fmt::Display for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } + } + //--- Arithmetic impl Add for $name { diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index f2cf6b710..d14cb6702 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -234,6 +234,7 @@ impl<'b> ParseBytes<'b> for EdnsOption<'b> { impl<'b> SplitBytes<'b> for EdnsOption<'b> { fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { + println!("Got bytes: {bytes:?}"); let (code, rest) = OptionCode::split_bytes(bytes)?; let (data, rest) = <&SizePrefixed<[u8]>>::split_bytes(rest)?; From 5f016e503d2db0e304eb134d5a79167cf9f2118d Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 4 Feb 2025 18:40:33 +0100 Subject: [PATCH 130/167] [new_base] Import various fixes from 'new-net-server' --- src/new_base/build/message.rs | 19 ++++++++----- src/new_base/build/question.rs | 8 +++--- src/new_base/build/record.rs | 8 +++--- src/new_base/name/reversed.rs | 36 ++++++++++++----------- src/new_base/parse/mod.rs | 2 +- src/new_base/question.rs | 52 ++++++++++++++++++++++++++++++++-- src/new_base/record.rs | 50 ++++++++++++++++++++++++++++---- src/new_base/wire/ints.rs | 6 ++++ src/new_edns/mod.rs | 3 +- 9 files changed, 141 insertions(+), 43 deletions(-) diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 5e969115f..0dc2bdd95 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -19,17 +19,17 @@ use super::{ /// This is a high-level building interface, offering methods to put together /// entire questions and records. It directly writes into an allocated buffer /// (on the stack or the heap). -pub struct MessageBuilder<'b> { +pub struct MessageBuilder<'b, 'c> { /// The message being constructed. pub(super) message: &'b mut Message, /// Context for building. - pub(super) context: &'b mut BuilderContext, + pub(super) context: &'c mut BuilderContext, } //--- Initialization -impl<'b> MessageBuilder<'b> { +impl<'b, 'c> MessageBuilder<'b, 'c> { /// Initialize an empty [`MessageBuilder`]. /// /// The message header is left uninitialized. use [`Self::header_mut()`] @@ -41,7 +41,7 @@ impl<'b> MessageBuilder<'b> { /// possible size for a DNS message). pub fn new( buffer: &'b mut [u8], - context: &'b mut BuilderContext, + context: &'c mut BuilderContext, ) -> Self { let message = Message::parse_bytes_by_mut(buffer) .expect("The caller's buffer is at least 12 bytes big"); @@ -52,7 +52,7 @@ impl<'b> MessageBuilder<'b> { //--- Inspection -impl MessageBuilder<'_> { +impl MessageBuilder<'_, '_> { /// The message header. pub fn header(&self) -> &Header { &self.message.header @@ -86,9 +86,14 @@ impl MessageBuilder<'_> { //--- Interaction -impl MessageBuilder<'_> { +impl<'b> MessageBuilder<'b, '_> { + /// End the builder, returning the built message. + pub fn finish(self) -> &'b Message { + self.message.slice_to(self.context.size) + } + /// Reborrow the builder with a shorter lifetime. - pub fn reborrow(&mut self) -> MessageBuilder<'_> { + pub fn reborrow(&mut self) -> MessageBuilder<'_, '_> { MessageBuilder { message: self.message, context: self.context, diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs index 95fa095ae..addd72d91 100644 --- a/src/new_base/build/question.rs +++ b/src/new_base/build/question.rs @@ -19,7 +19,7 @@ use super::{BuildCommitted, BuildIntoMessage, MessageBuilder, MessageState}; /// commit (finish building) or cancel (remove) the question. pub struct QuestionBuilder<'b> { /// The underlying message builder. - builder: MessageBuilder<'b>, + builder: MessageBuilder<'b, 'b>, /// The offset of the question name. name: u16, @@ -33,7 +33,7 @@ impl<'b> QuestionBuilder<'b> { /// The provided builder must be empty (i.e. must not have uncommitted /// content). pub(super) fn build( - mut builder: MessageBuilder<'b>, + mut builder: MessageBuilder<'b, 'b>, question: &Question, ) -> Result { // TODO: Require that the QNAME serialize correctly? @@ -51,7 +51,7 @@ impl<'b> QuestionBuilder<'b> { /// `builder.message().contents[name..]` must represent a valid /// [`Question`] in the wire format. pub unsafe fn from_raw_parts( - builder: MessageBuilder<'b>, + builder: MessageBuilder<'b, 'b>, name: u16, ) -> Self { Self { builder, name } @@ -84,7 +84,7 @@ impl<'b> QuestionBuilder<'b> { } /// Deconstruct this [`QuestionBuilder`] into its raw parts. - pub fn into_raw_parts(self) -> (MessageBuilder<'b>, u16) { + pub fn into_raw_parts(self) -> (MessageBuilder<'b, 'b>, u16) { (self.builder, self.name) } } diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index f74418a13..7d9a48b5d 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -23,7 +23,7 @@ use super::{ /// cancel (remove) the record. pub struct RecordBuilder<'b> { /// The underlying message builder. - builder: MessageBuilder<'b>, + builder: MessageBuilder<'b, 'b>, /// The offset of the record name. name: u16, @@ -40,7 +40,7 @@ impl<'b> RecordBuilder<'b> { /// The provided builder must be empty (i.e. must not have uncommitted /// content). pub(super) fn build( - mut builder: MessageBuilder<'b>, + mut builder: MessageBuilder<'b, 'b>, record: &Record, ) -> Result where @@ -97,7 +97,7 @@ impl<'b> RecordBuilder<'b> { /// [`Record`] in the wire format. `contents[data..]` must represent the /// record data (i.e. immediately after the record data size field). pub unsafe fn from_raw_parts( - builder: MessageBuilder<'b>, + builder: MessageBuilder<'b, 'b>, name: u16, data: u16, ) -> Self { @@ -151,7 +151,7 @@ impl<'b> RecordBuilder<'b> { } /// Deconstruct this [`RecordBuilder`] into its raw parts. - pub fn into_raw_parts(self) -> (MessageBuilder<'b>, u16, u16) { + pub fn into_raw_parts(self) -> (MessageBuilder<'b, 'b>, u16, u16) { let (name, data) = (self.name, self.data); let this = ManuallyDrop::new(self); let this = (&*this) as *const Self; diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 551082a31..2edf9a75b 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -195,24 +195,20 @@ impl Hash for RevName { impl fmt::Debug for RevName { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - struct RevLabels<'a>(&'a RevName); - - impl fmt::Debug for RevLabels<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut first = true; - self.0.labels().try_for_each(|label| { - if !first { - f.write_str(".")?; - } else { - first = false; - } - - label.fmt(f) - }) + f.write_str("RevName(")?; + + let mut first = true; + self.labels().try_for_each(|label| { + if !first { + f.write_str(".")?; + } else { + first = false; } - } - f.debug_tuple("RevName").field(&RevLabels(self)).finish() + fmt::Display::fmt(&label, f) + })?; + + f.write_str(")") } } @@ -482,7 +478,7 @@ impl AsMut for RevNameBuf { } } -//--- Forwarding equality, comparison, and hashing +//--- Forwarding equality, comparison, hashing, and formatting impl PartialEq for RevNameBuf { fn eq(&self, that: &Self) -> bool { @@ -509,3 +505,9 @@ impl Hash for RevNameBuf { (**self).hash(state) } } + +impl fmt::Debug for RevNameBuf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} diff --git a/src/new_base/parse/mod.rs b/src/new_base/parse/mod.rs index 03ac16995..5a16fd17f 100644 --- a/src/new_base/parse/mod.rs +++ b/src/new_base/parse/mod.rs @@ -54,7 +54,7 @@ impl<'a, T: ?Sized + SplitBytesByRef> SplitMessageBytes<'a> for &'a T { start: usize, ) -> Result<(Self, usize), ParseError> { T::split_bytes_by_ref(&contents[start..]) - .map(|(this, rest)| (this, contents.len() - start - rest.len())) + .map(|(this, rest)| (this, contents.len() - rest.len())) } } diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 34e6dc282..c7a5203ce 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -1,5 +1,7 @@ //! DNS questions. +use core::fmt; + use domain_macros::*; use super::{ @@ -12,7 +14,7 @@ use super::{ //----------- Question ------------------------------------------------------- /// A DNS question. -#[derive(Clone, BuildBytes, ParseBytes, SplitBytes)] +#[derive(Clone, Debug, BuildBytes, ParseBytes, SplitBytes)] pub struct Question { /// The domain name being requested. pub qname: N, @@ -40,6 +42,19 @@ impl Question { } } +//--- Interaction + +impl Question { + /// Map the name in this question to another type. + pub fn map_name R>(self, f: F) -> Question { + Question { + qname: (f)(self.qname), + qtype: self.qtype, + qclass: self.qclass, + } + } +} + //--- Parsing from DNS messages impl<'a, N> SplitMessageBytes<'a> for Question @@ -59,6 +74,7 @@ where impl<'a, N> ParseMessageBytes<'a> for Question where + // TODO: Reduce to 'ParseMessageBytes'. N: SplitMessageBytes<'a>, { fn parse_message_bytes( @@ -95,7 +111,6 @@ where #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -154,13 +169,32 @@ impl QType { pub const AAAA: Self = Self::new(28); } +//--- Formatting + +impl fmt::Debug for QType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::A => "QType::A", + Self::NS => "QType::NS", + Self::CNAME => "QType::CNAME", + Self::SOA => "QType::SOA", + Self::WKS => "QType::WKS", + Self::PTR => "QType::PTR", + Self::HINFO => "QType::HINFO", + Self::MX => "QType::MX", + Self::TXT => "QType::TXT", + Self::AAAA => "QType::AAAA", + _ => return write!(f, "QType({})", self.code), + }) + } +} + //----------- QClass --------------------------------------------------------- /// The class of a question. #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -194,3 +228,15 @@ impl QClass { /// The CHAOS class. pub const CH: Self = Self::new(3); } + +//--- Formatting + +impl fmt::Debug for QClass { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::IN => "QClass::IN", + Self::CH => "QClass::CH", + _ => return write!(f, "QClass({})", self.code), + }) + } +} diff --git a/src/new_base/record.rs b/src/new_base/record.rs index badbab04e..85eefd732 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -1,6 +1,6 @@ //! DNS records. -use core::{borrow::Borrow, ops::Deref}; +use core::{borrow::Borrow, fmt, ops::Deref}; use super::{ build::{self, BuildIntoMessage, BuildResult}, @@ -15,7 +15,7 @@ use super::{ //----------- Record --------------------------------------------------------- /// A DNS record. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Record { /// The name of the record. pub rname: N, @@ -76,7 +76,7 @@ where let (_, rest) = <&SizePrefixed<[u8]>>::split_message_bytes(contents, rest)?; let rdata = - D::parse_record_data(&contents[..rest], rdata_start, rtype)?; + D::parse_record_data(&contents[..rest], rdata_start + 2, rtype)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) } @@ -178,7 +178,6 @@ where #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -240,13 +239,33 @@ impl RType { pub const OPT: Self = Self::new(41); } +//--- Formatting + +impl fmt::Debug for RType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::A => "RType::A", + Self::NS => "RType::NS", + Self::CNAME => "RType::CNAME", + Self::SOA => "RType::SOA", + Self::WKS => "RType::WKS", + Self::PTR => "RType::PTR", + Self::HINFO => "RType::HINFO", + Self::MX => "RType::MX", + Self::TXT => "RType::TXT", + Self::AAAA => "RType::AAAA", + Self::OPT => "RType::OPT", + _ => return write!(f, "RType({})", self.code), + }) + } +} + //----------- RClass --------------------------------------------------------- /// The class of a record. #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -281,13 +300,24 @@ impl RClass { pub const CH: Self = Self::new(3); } +//--- Formatting + +impl fmt::Debug for RClass { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::IN => "RClass::IN", + Self::CH => "RClass::CH", + _ => return write!(f, "RClass({})", self.code), + }) + } +} + //----------- TTL ------------------------------------------------------------ /// How long a record can be cached. #[derive( Copy, Clone, - Debug, PartialEq, Eq, PartialOrd, @@ -322,6 +352,14 @@ impl From for u32 { } } +//--- Formatting + +impl fmt::Debug for TTL { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "TTL({})", self.value) + } +} + //----------- ParseRecordData ------------------------------------------------ /// Parsing DNS record data. diff --git a/src/new_base/wire/ints.rs b/src/new_base/wire/ints.rs index acd2990e2..a8170ea83 100644 --- a/src/new_base/wire/ints.rs +++ b/src/new_base/wire/ints.rs @@ -122,6 +122,12 @@ macro_rules! define_int { } } + impl fmt::Display for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } + } + //--- Arithmetic impl Add for $name { diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index a0640dac1..d14cb6702 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -191,7 +191,7 @@ impl fmt::Debug for EdnsFlags { //----------- EdnsOption ----------------------------------------------------- /// An Extended DNS option. -#[derive(Debug)] +#[derive(Clone, Debug)] #[non_exhaustive] pub enum EdnsOption<'b> { /// A client's request for a DNS cookie. @@ -234,6 +234,7 @@ impl<'b> ParseBytes<'b> for EdnsOption<'b> { impl<'b> SplitBytes<'b> for EdnsOption<'b> { fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { + println!("Got bytes: {bytes:?}"); let (code, rest) = OptionCode::split_bytes(bytes)?; let (data, rest) = <&SizePrefixed<[u8]>>::split_bytes(rest)?; From 2c13124dee64c9f94fb322c4b55a10f59a706d6b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 6 Feb 2025 15:07:33 +0100 Subject: [PATCH 131/167] [new_server] Implement cookie middleware --- Cargo.toml | 2 +- src/new_edns/cookie.rs | 155 ++++++++--- src/new_edns/mod.rs | 67 ++++- src/new_server/exchange.rs | 131 ++++++++- src/new_server/layers/cookie.rs | 459 ++++++++++++++++++++++++++++++++ src/new_server/layers/mod.rs | 4 + src/new_server/mod.rs | 2 + src/new_server/transport/mod.rs | 37 ++- 8 files changed, 804 insertions(+), 53 deletions(-) create mode 100644 src/new_server/layers/cookie.rs create mode 100644 src/new_server/layers/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 6ea514ec9..b811b9250 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,7 +75,7 @@ zonefile = ["bytes", "serde", "std"] # Unstable features unstable-client-transport = ["moka", "net", "tracing"] -unstable-server-transport = ["dep:bumpalo", "arc-swap", "chrono/clock", "libc", "net", "siphasher", "tracing"] +unstable-server-transport = ["dep:bumpalo", "arc-swap", "chrono/clock", "libc", "net", "rand", "siphasher", "tracing"] unstable-sign = ["std", "dep:secrecy", "unstable-validate", "time/formatting"] unstable-stelline = ["tokio/test-util", "tracing", "tracing-subscriber", "tsig", "unstable-client-transport", "unstable-server-transport", "zonefile"] unstable-validate = ["bytes", "std", "ring"] diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs index 77c810be5..0de4d065e 100644 --- a/src/new_edns/cookie.rs +++ b/src/new_edns/cookie.rs @@ -5,20 +5,22 @@ //! [RFC 7873]: https://datatracker.ietf.org/doc/html/rfc7873 //! [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 -use core::fmt; +use core::{ + borrow::{Borrow, BorrowMut}, + fmt, + hash::{Hash, Hasher}, + ops::{Deref, DerefMut}, +}; -#[cfg(all(feature = "std", feature = "siphasher"))] -use core::ops::Range; - -#[cfg(all(feature = "std", feature = "siphasher"))] -use std::net::IpAddr; +#[cfg(feature = "siphasher")] +use core::{net::IpAddr, ops::Range}; use domain_macros::*; -use crate::new_base::Serial; +use crate::new_base::{wire::ParseBytesByRef, Serial}; -#[cfg(all(feature = "std", feature = "siphasher"))] -use crate::new_base::wire::{AsBytes, TruncationError}; +#[cfg(feature = "siphasher")] +use crate::new_base::wire::AsBytes; //----------- ClientCookie --------------------------------------------------- @@ -57,35 +59,25 @@ impl ClientCookie { impl ClientCookie { /// Build a [`Cookie`] in response to this request. /// - /// A 24-byte version-1 interoperable cookie will be generated and written - /// to the given buffer. If the buffer is big enough, the remaining part - /// of the buffer is returned. - #[cfg(all(feature = "std", feature = "siphasher"))] - pub fn respond_into<'b>( - &self, - addr: IpAddr, - secret: &[u8; 16], - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - use core::hash::Hasher; - + /// A 24-byte version-1 interoperable cookie will be returned. + #[cfg(feature = "siphasher")] + pub fn respond(&self, addr: IpAddr, secret: &[u8; 16]) -> CookieBuf { use siphasher::sip::SipHasher24; - use crate::new_base::wire::BuildBytes; + // Construct a buffer to write into. + let mut bytes = [0u8; 24]; - // Build and hash the cookie simultaneously. - let mut hasher = SipHasher24::new_with_key(secret); - - bytes = self.build_bytes(bytes)?; - hasher.write(self.as_bytes()); + bytes[0..8].copy_from_slice(self.as_bytes()); // The version number and the reserved octets. - bytes = [1, 0, 0, 0].build_bytes(bytes)?; - hasher.write(&[1, 0, 0, 0]); + bytes[8..12].copy_from_slice(&[1, 0, 0, 0]); let timestamp = Serial::unix_time(); - bytes = timestamp.build_bytes(bytes)?; - hasher.write(timestamp.as_bytes()); + bytes[12..16].copy_from_slice(timestamp.as_bytes()); + + // Hash the cookie. + let mut hasher = SipHasher24::new_with_key(secret); + hasher.write(&bytes[0..16]); match addr { IpAddr::V4(addr) => hasher.write(&addr.octets()), @@ -93,9 +85,11 @@ impl ClientCookie { } let hash = hasher.finish().to_le_bytes(); - bytes = hash.build_bytes(bytes)?; + bytes[16..24].copy_from_slice(&hash); - Ok(bytes) + let cookie = Cookie::parse_bytes_by_ref(&bytes) + .expect("Any 24-byte string is a valid 'Cookie'"); + CookieBuf::copy_from(cookie) } } @@ -194,7 +188,7 @@ impl Cookie { /// valid. /// /// [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 - #[cfg(all(feature = "std", feature = "siphasher"))] + #[cfg(feature = "siphasher")] pub fn verify( &self, addr: IpAddr, @@ -229,6 +223,99 @@ impl Cookie { } } +//----------- CookieBuf ------------------------------------------------------ + +/// A 41-byte buffer holding a [`Cookie`]. +#[derive(Clone)] +pub struct CookieBuf { + /// The size of the cookie, in bytes. + /// + /// This value is between 24 and 40, inclusive. + size: u8, + + /// The cookie data, as raw bytes. + data: [u8; 40], +} + +//--- Construction + +impl CookieBuf { + /// Copy a [`Cookie`] into a [`CookieBuf`]. + pub fn copy_from(cookie: &Cookie) -> Self { + let mut data = [0u8; 40]; + let cookie = cookie.as_bytes(); + data[..cookie.len()].copy_from_slice(cookie); + let size = cookie.len() as u8; + Self { size, data } + } +} + +//--- Access to the underlying 'Cookie' + +impl Deref for CookieBuf { + type Target = Cookie; + + fn deref(&self) -> &Self::Target { + let bytes = &self.data[..self.size as usize]; + // SAFETY: A 'CookieBuf' always contains a valid 'Cookie'. + unsafe { Cookie::parse_bytes_by_ref(bytes).unwrap_unchecked() } + } +} + +impl DerefMut for CookieBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + let bytes = &mut self.data[..self.size as usize]; + // SAFETY: A 'CookieBuf' always contains a valid 'Cookie'. + unsafe { Cookie::parse_bytes_by_mut(bytes).unwrap_unchecked() } + } +} + +impl Borrow for CookieBuf { + fn borrow(&self) -> &Cookie { + self + } +} + +impl BorrowMut for CookieBuf { + fn borrow_mut(&mut self) -> &mut Cookie { + self + } +} + +impl AsRef for CookieBuf { + fn as_ref(&self) -> &Cookie { + self + } +} + +impl AsMut for CookieBuf { + fn as_mut(&mut self) -> &mut Cookie { + self + } +} + +//--- Forwarding formatting, equality and hashing + +impl fmt::Debug for CookieBuf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl PartialEq for CookieBuf { + fn eq(&self, that: &Self) -> bool { + **self == **that + } +} + +impl Eq for CookieBuf {} + +impl Hash for CookieBuf { + fn hash(&self, state: &mut H) { + (**self).hash(state) + } +} + //----------- CookieError ---------------------------------------------------- /// An invalid [`Cookie`] was encountered. diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index d14cb6702..89d1074af 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -8,19 +8,21 @@ use domain_macros::*; use crate::{ new_base::{ + name::RevName, parse::{ParseMessageBytes, SplitMessageBytes}, wire::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, SizePrefixed, SplitBytes, TruncationError, U16, }, + RClass, RType, Record, }, - new_rdata::Opt, + new_rdata::{Opt, RecordData}, }; //----------- EDNS option modules -------------------------------------------- mod cookie; -pub use cookie::{ClientCookie, Cookie}; +pub use cookie::{ClientCookie, Cookie, CookieBuf, CookieError}; mod ext_err; pub use ext_err::{ExtError, ExtErrorCode}; @@ -46,6 +48,51 @@ pub struct EdnsRecord<'a> { pub options: SizePrefixed<&'a Opt>, } +//--- Converting to and from 'Record' + +impl<'n, 'a, DN> TryFrom>> + for EdnsRecord<'a> +{ + type Error = ParseError; + + fn try_from( + value: Record<&'n RevName, RecordData<'a, DN>>, + ) -> Result { + if !value.rname.is_root() || value.rtype != RType::OPT { + return Err(ParseError); + } + + let RecordData::Opt(opt) = value.rdata else { + return Err(ParseError); + }; + + let ttl = value.ttl.value.get().to_be_bytes(); + Ok(Self { + max_udp_payload: value.rclass.code, + ext_rcode: ttl[0], + version: ttl[1], + flags: u16::from_be_bytes([ttl[2], ttl[3]]).into(), + options: SizePrefixed::new(opt), + }) + } +} + +impl<'a, DN> From> for Record<&RevName, RecordData<'a, DN>> { + fn from(value: EdnsRecord<'a>) -> Self { + let flags = value.flags.bits().to_be_bytes(); + let ttl = [value.ext_rcode, value.version, flags[0], flags[1]]; + Record { + rname: RevName::ROOT, + rtype: RType::OPT, + rclass: RClass { + code: value.max_udp_payload, + }, + ttl: u32::from_be_bytes(ttl).into(), + rdata: RecordData::Opt(*value.options), + } + } +} + //--- Parsing from DNS messages impl<'a> SplitMessageBytes<'a> for EdnsRecord<'a> { @@ -177,6 +224,22 @@ impl EdnsFlags { } } +//--- Conversion to and from integers + +impl From for EdnsFlags { + fn from(value: u16) -> Self { + Self { + inner: U16::new(value), + } + } +} + +impl From for u16 { + fn from(value: EdnsFlags) -> Self { + value.inner.get() + } +} + //--- Formatting impl fmt::Debug for EdnsFlags { diff --git a/src/new_server/exchange.rs b/src/new_server/exchange.rs index 20b1621f5..9233aa461 100644 --- a/src/new_server/exchange.rs +++ b/src/new_server/exchange.rs @@ -15,10 +15,10 @@ use crate::{ build::{BuilderContext, MessageBuilder}, name::{RevName, RevNameBuf}, parse::SplitMessageBytes, - wire::{BuildBytes, ParseError, TruncationError, U16}, - HeaderFlags, Message, Question, RType, Record, + wire::{BuildBytes, ParseError, SizePrefixed, TruncationError, U16}, + HeaderFlags, Message, Question, RType, Record, SectionCounts, }, - new_edns::EdnsOption, + new_edns::{EdnsOption, EdnsRecord}, new_rdata::{Opt, RecordData}, }; @@ -47,6 +47,13 @@ pub struct Exchange<'a> { pub metadata: Vec, } +impl Exchange<'_> { + /// Begin a response with the given code. + pub fn respond(&mut self, code: ResponseCode) { + self.response.respond_to(&self.request, code); + } +} + //----------- OutgoingResponse ----------------------------------------------- /// An [`Exchange`] with an initialized response message. @@ -159,11 +166,8 @@ impl<'a> ParsedMessage<'a> { offset, )?; - this.questions.push(Question { - qname: map_name(question.qname, alloc), - qtype: question.qtype, - qclass: question.qclass, - }); + this.questions + .push(question.map_name(|n| map_name(n, alloc))); offset = rest; } @@ -266,10 +270,7 @@ impl<'a> ParsedMessage<'a> { let header = builder.header_mut(); header.id = self.id; header.flags = self.flags; - header.counts.questions.set(self.questions.len() as u16); - header.counts.answers.set(self.answers.len() as u16); - header.counts.authorities.set(self.authorities.len() as u16); - header.counts.additional.set(self.additional.len() as u16); + header.counts = SectionCounts::default(); // Build the question section. for question in &self.questions { @@ -309,7 +310,7 @@ impl<'a> ParsedMessage<'a> { let uninit_len = uninit.len(); let appended = delegate.uninitialized().len() - uninit_len; delegate.mark_appended(appended); - core::mem::drop(delegate); + delegate.commit(); builder.commit(); edns_built = true; @@ -328,6 +329,13 @@ impl<'a> ParsedMessage<'a> { } } +impl ParsedMessage<'_> { + /// Whether this message has an EDNS record. + pub fn has_edns(&self) -> bool { + self.additional.iter().any(|r| r.rtype == RType::OPT) + } +} + impl ParsedMessage<'_> { /// Reset this object to a blank message. /// @@ -341,6 +349,103 @@ impl ParsedMessage<'_> { self.additional.clear(); self.options.clear(); } + + /// Begin a new message in response to the given one. + /// + /// The contents of `self` will be overwritten. The message ID and flags + /// will be copied from the request message, and the given response code + /// will be set. The OPT record, if any, will also be copied (without any + /// EDNS options); if an extended response code is used, it will be added. + pub fn respond_to( + &mut self, + request: &ParsedMessage<'_>, + code: ResponseCode, + ) { + self.reset(); + self.id = request.id; + self.flags = request.flags.respond(code.header_bits()); + + if let Some(edns) = request + .additional + .iter() + .find_map(|r| EdnsRecord::try_from(r.clone()).ok()) + { + // Copy the EDNS record, without any options. + let record = EdnsRecord { + max_udp_payload: edns.max_udp_payload, + ext_rcode: edns.ext_rcode, + version: edns.version, + flags: edns.flags, + options: SizePrefixed::new(Opt::EMPTY), + }; + self.additional.push(record.into()); + } + } +} + +//----------- ResponseCode --------------------------------------------------- + +/// A (possibly extended) DNS response code. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum ResponseCode { + /// The request was answered successfully. + Success, + + /// The request was misformatted. + FormatError, + + /// The server encountered an internal error. + ServerFailure, + + /// The queried domain name does not exist. + NonExistentDomain, + + /// The server does not support the requested kind of query. + NotImplemented, + + /// Policy prevents the server from answering the query. + Refused, + + /// The TSIG record in the request was invalid. + InvalidTSIG, + + /// The server does not support the request's OPT record version. + UnsupportedOptVersion, + + /// The request did not contain a valid EDNS server cookie. + BadCookie, +} + +impl ResponseCode { + /// This code's representation in the DNS message header. + pub const fn header_bits(&self) -> u8 { + match self { + Self::Success => 0, + Self::FormatError => 1, + Self::ServerFailure => 2, + Self::NonExistentDomain => 3, + Self::NotImplemented => 4, + Self::Refused => 5, + Self::InvalidTSIG => 9, + Self::UnsupportedOptVersion => 0, + Self::BadCookie => 7, + } + } + + /// This code's representation in the EDNS record header. + pub const fn edns_bits(&self) -> u8 { + match self { + Self::Success => 0, + Self::FormatError => 0, + Self::ServerFailure => 0, + Self::NonExistentDomain => 0, + Self::NotImplemented => 0, + Self::Refused => 0, + Self::InvalidTSIG => 0, + Self::UnsupportedOptVersion => 1, + Self::BadCookie => 1, + } + } } //----------- Metadata ------------------------------------------------------- diff --git a/src/new_server/layers/cookie.rs b/src/new_server/layers/cookie.rs new file mode 100644 index 000000000..a54700f15 --- /dev/null +++ b/src/new_server/layers/cookie.rs @@ -0,0 +1,459 @@ +//! DNS cookie management. + +use core::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + ops::{ControlFlow, Range}, +}; + +use std::{sync::Arc, vec::Vec}; + +use arc_swap::ArcSwap; +use rand::{CryptoRng, Rng, RngCore}; + +use crate::{ + new_base::{ + wire::{AsBytes, ParseBytesByRef}, + Serial, + }, + new_edns::{ + ClientCookie, Cookie, CookieBuf, CookieError, EdnsOption, OptionCode, + }, + new_server::{ + exchange::{Metadata, OutgoingResponse, ResponseCode}, + transport::{SourceIpAddr, UdpMetadata}, + Exchange, LocalServiceLayer, ServiceLayer, + }, +}; + +//----------- CookieLayer ---------------------------------------------------- + +/// Server-side DNS cookie management. +#[derive(Debug)] +pub struct CookieLayer { + /// The cookie policy to use. + policy: ArcSwap, + + /// The secrets to use for signing and verifying. + secrets: ArcSwap, +} + +//--- Interaction + +impl CookieLayer { + /// Construct a new [`CookieLayer`]. + pub fn new(policy: CookiePolicy, secrets: CookieSecrets) -> Self { + Self { + policy: ArcSwap::new(Arc::new(policy)), + secrets: ArcSwap::new(Arc::new(secrets)), + } + } + + /// Load the cookie policy. + /// + /// The current state of the policy is loaded. The policy may be changed + /// by a different thread, so future calls to the method may result in + /// different policies. + pub fn get_policy(&self) -> Arc { + self.policy.load_full() + } + + /// Replace the cookie policy. + /// + /// This will atomically update the policy, so that future callers of + /// [`get_policy()`](Self::get_policy()) will (soon but not necessarily + /// immediately) see the updated policy. + pub fn set_policy(&self, policy: CookiePolicy) { + self.policy.store(Arc::new(policy)); + } + + /// Load the cookie secrets. + /// + /// The current state of the secrets is loaded. The secrets may be + /// changed by a different thread, so future calls to the method may + /// result in different secrets. + pub fn get_secrets(&self) -> Arc { + self.secrets.load_full() + } + + /// Replace the cookie secrets. + /// + /// This will atomically update the secrets, so that future callers of + /// [`get_secrets()`](Self::get_secrets()) will (soon but not necessarily + /// immediately) see the updated secrets. + pub fn set_secrets(&self, secrets: CookieSecrets) { + self.secrets.store(Arc::new(secrets)); + } +} + +//--- Processing incoming requests + +impl CookieLayer { + /// Respond to an incoming request with an alleged server cookie. + fn process_incoming_server_cookie<'a>( + &self, + exchange: &mut Exchange<'a>, + addr: IpAddr, + cookie: &'a Cookie, + ) -> ControlFlow<()> { + // Determine the validity period of the cookie. + let now = Serial::unix_time(); + let validity = now + -300..now + 3600; + + // Check if the cookie is actually valid. + if self.secrets.load().verify(&addr, validity, cookie).is_err() { + // Simply ignore the server part. + return self.process_incoming_wo_server_cookie( + exchange, + addr, + Some(cookie.request()), + ); + } + + // Determine whether the cookie needs to be renewed. + let expiry = now + 1800; + let regenerate = cookie.timestamp() >= expiry; + + // Remember the cookie status. + let cookie = CookieBuf::copy_from(cookie); + let metadata = CookieMetadata::ServerCookie { cookie, regenerate }; + exchange.metadata.push(Metadata::new(metadata)); + + // Continue into the next layer. + ControlFlow::Continue(()) + } + + /// Respond to an incoming request without a (valid) server cookie. + fn process_incoming_wo_server_cookie<'a>( + &self, + exchange: &mut Exchange<'a>, + addr: IpAddr, + cookie: Option<&'a ClientCookie>, + ) -> ControlFlow<()> { + // RFC 7873, section 5.2.3: + // + // > Servers MUST, at least occasionally, respond to such requests to + // > inform the client of the correct Server Cookie. This is + // > necessary so that such a client can bootstrap to the more secure + // > state where requests and responses have recognized Server Cookies + // > and Client Cookies. A server is not expected to maintain + // > per-client state to achieve this. For example, it could respond + // > to every Nth request across all clients. + + // We rate-limit requests based on the cookie policy. If the request + // originates from a restricted IP address, the request is allowed to + // continue with a small probability. All requests from unrestricted + // IP addresses are allowed to go through. All non-UDP requests are + // allowed to go through anyway. + if !exchange.metadata.iter().any(|m| m.is::()) + || !self.policy.load().is_required_for(addr) + || rand::thread_rng().gen_bool(0.05) + { + // The request is allowed to go through. + let metadata = match cookie { + Some(&cookie) => CookieMetadata::ClientCookie(cookie), + None => CookieMetadata::None, + }; + exchange.metadata.push(Metadata::new(metadata)); + return ControlFlow::Continue(()); + } + + // Block the request. + if exchange.request.has_edns() { + exchange.respond(ResponseCode::BadCookie); + } else { + exchange.respond(ResponseCode::Refused); + exchange.response.flags = + exchange.response.flags.set_truncated(true); + } + exchange + .response + .questions + .append(&mut exchange.request.questions); + + ControlFlow::Break(()) + } +} + +//--- Processing outgoing responses + +impl CookieLayer { + /// Generate an EDNS COOKIE option for a response. + fn generate_cookie( + &self, + addr: IpAddr, + cookie: ClientCookie, + ) -> CookieBuf { + cookie.respond(addr, &self.secrets.load().primary) + } +} + +//--- ServiceLayer + +impl ServiceLayer for CookieLayer { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + // Check for an EDNS COOKIE option. + let cookie = exchange + .request + .options + .iter() + .find(|option| option.code() == OptionCode::COOKIE) + .cloned(); + + // Determine the IP address the request originated from. + let Some(&SourceIpAddr(addr)) = + exchange.metadata.iter().find_map(|m| m.try_as()) + else { + // We couldn't determine the source address. + // TODO: This is unexpected, log it. + return ControlFlow::Continue(()); + }; + + match cookie { + Some(EdnsOption::Cookie(cookie)) => { + self.process_incoming_server_cookie(exchange, addr, cookie) + } + + Some(EdnsOption::ClientCookie(cookie)) => self + .process_incoming_wo_server_cookie( + exchange, + addr, + Some(cookie), + ), + + None => { + self.process_incoming_wo_server_cookie(exchange, addr, None) + } + + _ => unreachable!(), + } + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + // Determine the IP address the request originated from. + let Some(&SourceIpAddr(addr)) = + response.metadata.iter().find_map(|m| m.try_as()) + else { + // We couldn't determine the source address. + // TODO: This is unexpected, log it. + return; + }; + + // Check for cookie metadata. + let cookie = match response.metadata.iter().find_map(|m| m.try_as()) { + // The request had a client cookie (and possibly an invalid server + // cookie). Generate a new server cookie and include it. + Some(CookieMetadata::ClientCookie(cookie)) => { + self.generate_cookie(addr, *cookie) + } + + // The request had a server cookie that may need to be renewed. + Some(CookieMetadata::ServerCookie { cookie, regenerate }) => { + if *regenerate { + self.generate_cookie(addr, *cookie.request()) + } else { + cookie.clone() + } + } + + // The request did not contain a cookie, or the cookie layer was + // disabled when answering this request. + Some(CookieMetadata::None) | None => return, + }; + + // Copy the cookie into the response. + // TODO: Check that the response includes an EDNS record. + let cookie = response.alloc.alloc_slice_copy((*cookie).as_bytes()); + let cookie = Cookie::parse_bytes_by_ref(cookie).unwrap(); + let option = EdnsOption::Cookie(cookie); + response.response.options.push(option); + } +} + +impl LocalServiceLayer for CookieLayer { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.process_incoming(exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + self.process_outgoing(response).await + } +} + +//----------- CookiePolicy --------------------------------------------------- + +/// Configuration for DNS cookie enforcement. +#[derive(Clone, Debug, Default)] +pub struct CookiePolicy { + /// IP addresses that must provide DNS cookies with their queries. + pub required: PrefixTree, + + /// IP addresses that need not provide DNS cookies with their queries. + pub allowed: PrefixTree, +} + +impl CookiePolicy { + /// Whether an IP address is required to use DNS cookies. + pub fn is_required_for(&self, addr: IpAddr) -> bool { + match (self.required.test(addr), self.allowed.test(addr)) { + // The address is restricted, but is more specifically allowed. + (Some(r), Some(a)) if a >= r => false, + + // The address is definitely restricted. + (Some(_), _) => true, + + // There are no restrictions on the address. + (None, _) => true, + } + } +} + +//----------- CookieSecrets -------------------------------------------------- + +/// The secrets used for DNS cookies. +#[derive(Clone, Debug)] +pub struct CookieSecrets { + /// The primary secret (used for generation and verification). + pub primary: [u8; 16], + + /// A secondary secret for verification. + pub secondary: [u8; 16], +} + +impl CookieSecrets { + /// Initialize [`CookieSecrets`] with a random primary. + pub fn generate() -> Self { + Self::generate_with(rand::thread_rng()) + } + + /// Initialize [`CookieSecrets`] with the given RNG. + pub fn generate_with(mut rng: impl CryptoRng + RngCore) -> Self { + let primary = rng.gen(); + Self { + primary, + secondary: primary, + } + } + + /// Verify the given cookie against these secrets. + fn verify( + &self, + addr: &IpAddr, + validity: Range, + cookie: &Cookie, + ) -> Result<(), CookieError> { + let Err(err) = cookie.verify(*addr, &self.primary, validity.clone()) + else { + return Ok(()); + }; + + // TODO: Compare secrets more carefully. + if self.primary == self.secondary { + return Err(err); + } + + cookie.verify(*addr, &self.secondary, validity) + } +} + +//----------- CookieMetadata ------------------------------------------------- + +/// Information about a DNS request's use of cookies. +pub enum CookieMetadata { + /// The request did not use DNS cookies. + None, + + /// The request included a DNS client cookie. + ClientCookie(ClientCookie), + + /// The request included a DNS server cookie. + ServerCookie { + /// The cookie used in the request. + cookie: CookieBuf, + + /// Whether a new cookie should be generated. + regenerate: bool, + }, +} + +//----------- PrefixTree ----------------------------------------------------- + +/// A set of IP addresses represented as prefixes. +#[derive(Clone, Debug, Default)] +pub struct PrefixTree { + /// A list of v4 prefixes, from longest to shortest. + v4_prefixes: Vec<(u8, Ipv4Addr)>, + + /// A list of v6 prefixes, from longest to shortest. + v6_prefixes: Vec<(u8, Ipv6Addr)>, +} + +impl PrefixTree { + /// Build a [`PrefixTree`] from an unsorted list of prefixes. + /// + /// The prefixes will be sorted before being used. Outside the valid + /// length of each prefix, only zero bits must be used. + pub fn from_prefixes( + mut v4_prefixes: Vec<(u8, Ipv4Addr)>, + mut v6_prefixes: Vec<(u8, Ipv6Addr)>, + ) -> Self { + v4_prefixes.sort_unstable_by(|a, b| a.0.cmp(&b.0).reverse()); + v6_prefixes.sort_unstable_by(|a, b| a.0.cmp(&b.0).reverse()); + Self::from_sorted_prefixes(v4_prefixes, v6_prefixes) + } + + /// Build a [`PrefixTree`] from a sorted list of prefixes. + /// + /// The prefixes must be sorted from longest to shortest. Within a + /// particular prefix length, the addresses are unordered. Outside the + /// valid length of each prefix, only zero bits must be used. + pub fn from_sorted_prefixes( + v4_prefixes: Vec<(u8, Ipv4Addr)>, + v6_prefixes: Vec<(u8, Ipv6Addr)>, + ) -> Self { + Self { + v4_prefixes, + v6_prefixes, + } + } + + /// Test whether an IP address is in this prefix tree. + /// + /// If a matching prefix is found, its length is returned. + pub fn test(&self, addr: IpAddr) -> Option { + match addr { + IpAddr::V4(addr) => self.test_v4(addr), + IpAddr::V6(addr) => self.test_v6(addr), + } + } + + /// Test whether an IPv4 address is in this prefix tree. + /// + /// If a matching prefix is found, its length is returned. + pub fn test_v4(&self, addr: Ipv4Addr) -> Option { + self.v4_prefixes + .iter() + .copied() + .find(|(_, prefix)| (prefix & addr) == *prefix) + .map(|(length, _)| length) + } + + /// Test whether an IPv6 address is in this prefix tree. + /// + /// If a matching prefix is found, its length is returned. + pub fn test_v6(&self, addr: Ipv6Addr) -> Option { + self.v6_prefixes + .iter() + .copied() + .find(|(_, prefix)| (prefix & addr) == *prefix) + .map(|(length, _)| length) + } +} diff --git a/src/new_server/layers/mod.rs b/src/new_server/layers/mod.rs new file mode 100644 index 000000000..885e64001 --- /dev/null +++ b/src/new_server/layers/mod.rs @@ -0,0 +1,4 @@ +//! Common plug-in functionality for DNS servers. + +pub mod cookie; +pub use cookie::CookieLayer; diff --git a/src/new_server/mod.rs b/src/new_server/mod.rs index ce8bd749e..d9aef6c1a 100644 --- a/src/new_server/mod.rs +++ b/src/new_server/mod.rs @@ -24,6 +24,8 @@ use exchange::OutgoingResponse; pub mod transport; +pub mod layers; + //----------- Service -------------------------------------------------------- /// A (multi-threaded) DNS service, that computes responses for requests. diff --git a/src/new_server/transport/mod.rs b/src/new_server/transport/mod.rs index 4bb193633..8cb7012d0 100644 --- a/src/new_server/transport/mod.rs +++ b/src/new_server/transport/mod.rs @@ -1,6 +1,6 @@ //! Network transports for DNS servers. -use core::net::SocketAddr; +use core::net::{IpAddr, SocketAddr}; use std::{io, sync::Arc, time::SystemTime, vec::Vec}; use bumpalo::Bump; @@ -11,7 +11,7 @@ use crate::{ wire::{AsBytes, ParseBytesByRef}, Message, }, - new_server::exchange::Allocator, + new_server::exchange::{Allocator, Metadata}, }; use super::{exchange::ParsedMessage, Exchange, Service}; @@ -67,18 +67,20 @@ pub async fn serve_udp( reception: SystemTime::now(), request, response: ParsedMessage::default(), - metadata: Vec::new(), + metadata: vec![Metadata::new(SourceIpAddr(peer.ip()))], }; // Generate the appropriate response. self.service.respond(&mut exchange).await; // Build up the response message. + println!("Returning response: {:?}", exchange.response); let mut buffer = vec![0u8; 65536]; let message = exchange.response.build(&mut buffer).unwrap_or_else(|_| { todo!("how to handle truncation errors?") }); + println!("In bytes: {:?}", message.as_bytes()); // Send the response back to the peer. let _ = self.socket.send_to(message.as_bytes(), peer).await; @@ -104,3 +106,32 @@ pub async fn serve_udp( tokio::task::spawn(state.clone().respond(buffer, peer)); } } + +//----------- SourceIpAddr --------------------------------------------------- + +/// The IP address a DNS request originated from. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct SourceIpAddr(pub IpAddr); + +impl From for SourceIpAddr { + fn from(value: IpAddr) -> Self { + Self(value) + } +} + +impl From for IpAddr { + fn from(value: SourceIpAddr) -> Self { + value.0 + } +} + +//----------- UdpMetadata ---------------------------------------------------- + +/// Information about a DNS request on a UDP socket. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct UdpMetadata { + /// The UDP port the request originated from. + /// + /// Use [`SourceIpAddr`] to determine the associated IP address. + pub port: u16, +} From 76655504e2bdc4ad2b6e415805245c790e34f135 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 6 Feb 2025 15:09:29 +0100 Subject: [PATCH 132/167] [new_edns] Remove debug 'println' --- src/new_edns/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index d14cb6702..f2cf6b710 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -234,7 +234,6 @@ impl<'b> ParseBytes<'b> for EdnsOption<'b> { impl<'b> SplitBytes<'b> for EdnsOption<'b> { fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { - println!("Got bytes: {bytes:?}"); let (code, rest) = OptionCode::split_bytes(bytes)?; let (data, rest) = <&SizePrefixed<[u8]>>::split_bytes(rest)?; From a4c7d49a8bfbb423df3e375a66649e3dc29dd5a4 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 6 Feb 2025 17:27:03 +0100 Subject: [PATCH 133/167] [new_base] Add some basic unit tests --- src/new_base/charstr.rs | 29 +++++++++++++++++++++++++++++ src/new_base/question.rs | 34 +++++++++++++++++++++++++++++++++- src/new_base/record.rs | 36 +++++++++++++++++++++++++++++++++++- src/new_base/serial.rs | 22 ++++++++++++++++++++++ 4 files changed, 119 insertions(+), 2 deletions(-) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 3eab4d7cd..adf8e6024 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -141,3 +141,32 @@ impl fmt::Debug for CharStr { .finish() } } + +//============ Tests ========================================================= + +#[cfg(test)] +mod test { + use super::CharStr; + + use crate::new_base::wire::{ + BuildBytes, ParseBytes, ParseError, SplitBytes, + }; + + #[test] + fn parse_build() { + let bytes = b"\x05Hello!"; + let (charstr, rest) = <&CharStr>::split_bytes(bytes).unwrap(); + assert_eq!(&charstr.octets, b"Hello"); + assert_eq!(rest, b"!"); + + assert_eq!(<&CharStr>::parse_bytes(bytes), Err(ParseError)); + assert!(<&CharStr>::parse_bytes(&bytes[..6]).is_ok()); + + let mut buffer = [0u8; 6]; + assert_eq!( + charstr.build_bytes(&mut buffer), + Ok(&mut [] as &mut [u8]) + ); + assert_eq!(buffer, &bytes[..6]); + } +} diff --git a/src/new_base/question.rs b/src/new_base/question.rs index c7a5203ce..82ea76f25 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -14,7 +14,7 @@ use super::{ //----------- Question ------------------------------------------------------- /// A DNS question. -#[derive(Clone, Debug, BuildBytes, ParseBytes, SplitBytes)] +#[derive(Clone, Debug, BuildBytes, ParseBytes, SplitBytes, PartialEq, Eq)] pub struct Question { /// The domain name being requested. pub qname: N, @@ -240,3 +240,35 @@ impl fmt::Debug for QClass { }) } } + +//============ Tests ========================================================= + +#[cfg(test)] +mod test { + use super::{QClass, QType, Question}; + + use crate::new_base::{ + name::Name, + wire::{BuildBytes, ParseBytes, ParseError, SplitBytes}, + }; + + #[test] + fn parse_build() { + let bytes = b"\x03com\x00\x00\x01\x00\x01\x2A"; + let (question, rest) = >::split_bytes(bytes).unwrap(); + assert_eq!(question.qname.as_bytes(), b"\x03com\x00"); + assert_eq!(question.qtype, QType::A); + assert_eq!(question.qclass, QClass::IN); + assert_eq!(rest, b"\x2A"); + + assert_eq!(>::parse_bytes(bytes), Err(ParseError)); + assert!(>::parse_bytes(&bytes[..9]).is_ok()); + + let mut buffer = [0u8; 9]; + assert_eq!( + question.build_bytes(&mut buffer), + Ok(&mut [] as &mut [u8]) + ); + assert_eq!(buffer, &bytes[..9]); + } +} diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 85eefd732..dd0771564 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -15,7 +15,7 @@ use super::{ //----------- Record --------------------------------------------------------- /// A DNS record. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Record { /// The name of the record. pub rname: N, @@ -448,3 +448,37 @@ impl AsRef<[u8]> for UnparsedRecordData { self } } + +//============ Tests ========================================================= + +#[cfg(test)] +mod test { + use super::{RClass, RType, Record, UnparsedRecordData, TTL}; + + use crate::new_base::{ + name::Name, + wire::{AsBytes, BuildBytes, ParseBytes, SplitBytes}, + }; + + #[test] + fn parse_build() { + type Subject<'a> = Record<&'a Name, &'a UnparsedRecordData>; + + let bytes = + b"\x03com\x00\x00\x01\x00\x01\x00\x00\x00\x2A\x00\x00\x54"; + let (record, rest) = Subject::split_bytes(bytes).unwrap(); + assert_eq!(record.rname.as_bytes(), b"\x03com\x00"); + assert_eq!(record.rtype, RType::A); + assert_eq!(record.rclass, RClass::IN); + assert_eq!(record.ttl, TTL::from(42)); + assert_eq!(record.rdata.as_bytes(), b""); + assert_eq!(rest, b"\x54"); + + assert!(Subject::parse_bytes(bytes).is_err()); + assert!(Subject::parse_bytes(&bytes[..15]).is_ok()); + + let mut buffer = [0u8; 15]; + assert_eq!(record.build_bytes(&mut buffer), Ok(&mut [] as &mut [u8])); + assert_eq!(buffer, &bytes[..15]); + } +} diff --git a/src/new_base/serial.rs b/src/new_base/serial.rs index af0e4a1a1..2e1b6583c 100644 --- a/src/new_base/serial.rs +++ b/src/new_base/serial.rs @@ -102,3 +102,25 @@ impl fmt::Display for Serial { self.0.get().fmt(f) } } + +//============ Tests ========================================================= + +#[cfg(test)] +mod test { + use super::Serial; + + #[test] + fn comparisons() { + // TODO: Use property-based testing. + assert!(Serial::from(u32::MAX) > Serial::from(u32::MAX / 2 + 1)); + assert!(Serial::from(0) > Serial::from(u32::MAX)); + assert!(Serial::from(1) > Serial::from(0)); + } + + #[test] + fn operations() { + // TODO: Use property-based testing. + assert_eq!(u32::from(Serial::from(1) + 1), 2); + assert_eq!(u32::from(Serial::from(u32::MAX) + 1), 0); + } +} From 2219209e03cea8d573376701b078656ec7bd30d9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 7 Feb 2025 09:20:35 +0100 Subject: [PATCH 134/167] [new_edns/cookie] Import 'AsBytes' unconditionally --- src/new_edns/cookie.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs index 0de4d065e..9b7c5a243 100644 --- a/src/new_edns/cookie.rs +++ b/src/new_edns/cookie.rs @@ -17,10 +17,10 @@ use core::{net::IpAddr, ops::Range}; use domain_macros::*; -use crate::new_base::{wire::ParseBytesByRef, Serial}; - -#[cfg(feature = "siphasher")] -use crate::new_base::wire::AsBytes; +use crate::new_base::{ + wire::{AsBytes, ParseBytesByRef}, + Serial, +}; //----------- ClientCookie --------------------------------------------------- From 0be40c47d67477eb04e9bcff79cd0f890c23fcde Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 7 Feb 2025 09:37:52 +0100 Subject: [PATCH 135/167] [new_edns] Remove leftover debugging --- src/new_edns/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index 89d1074af..fec98291d 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -297,7 +297,6 @@ impl<'b> ParseBytes<'b> for EdnsOption<'b> { impl<'b> SplitBytes<'b> for EdnsOption<'b> { fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { - println!("Got bytes: {bytes:?}"); let (code, rest) = OptionCode::split_bytes(bytes)?; let (data, rest) = <&SizePrefixed<[u8]>>::split_bytes(rest)?; From 54a42962956d31491f8a5acfc16a661c4ac16600 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 7 Feb 2025 10:24:25 +0100 Subject: [PATCH 136/167] [new_server] Use 'log' --- Cargo.toml | 2 +- src/new_server/layers/cookie.rs | 25 +++++++++++++++++++++++++ src/new_server/transport/mod.rs | 11 +++++++++-- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b811b9250..c934b8215 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,7 +75,7 @@ zonefile = ["bytes", "serde", "std"] # Unstable features unstable-client-transport = ["moka", "net", "tracing"] -unstable-server-transport = ["dep:bumpalo", "arc-swap", "chrono/clock", "libc", "net", "rand", "siphasher", "tracing"] +unstable-server-transport = ["dep:bumpalo", "arc-swap", "chrono/clock", "dep:log", "libc", "net", "rand", "siphasher", "tracing"] unstable-sign = ["std", "dep:secrecy", "unstable-validate", "time/formatting"] unstable-stelline = ["tokio/test-util", "tracing", "tracing-subscriber", "tsig", "unstable-client-transport", "unstable-server-transport", "zonefile"] unstable-validate = ["bytes", "std", "ring"] diff --git a/src/new_server/layers/cookie.rs b/src/new_server/layers/cookie.rs index a54700f15..7bd8e00aa 100644 --- a/src/new_server/layers/cookie.rs +++ b/src/new_server/layers/cookie.rs @@ -8,6 +8,7 @@ use core::{ use std::{sync::Arc, vec::Vec}; use arc_swap::ArcSwap; +use log::trace; use rand::{CryptoRng, Rng, RngCore}; use crate::{ @@ -101,6 +102,10 @@ impl CookieLayer { // Check if the cookie is actually valid. if self.secrets.load().verify(&addr, validity, cookie).is_err() { + trace!(target: "CookieLayer", + "Ignoring invalid server cookie in request {}", + exchange.request.id); + // Simply ignore the server part. return self.process_incoming_wo_server_cookie( exchange, @@ -109,6 +114,10 @@ impl CookieLayer { ); } + trace!(target: "CookieLayer", + "Validated server cookie in request {}", + exchange.request.id); + // Determine whether the cookie needs to be renewed. let expiry = now + 1800; let regenerate = cookie.timestamp() >= expiry; @@ -149,6 +158,9 @@ impl CookieLayer { || rand::thread_rng().gen_bool(0.05) { // The request is allowed to go through. + trace!(target: "CookieLayer", + "Allowing request {} regardless of missing/invalid server cookie", + exchange.request.id); let metadata = match cookie { Some(&cookie) => CookieMetadata::ClientCookie(cookie), None => CookieMetadata::None, @@ -158,6 +170,9 @@ impl CookieLayer { } // Block the request. + trace!(target: "CookieLayer", + "Blocking request {} due to missing/invalid server cookie", + exchange.request.id); if exchange.request.has_edns() { exchange.respond(ResponseCode::BadCookie); } else { @@ -246,14 +261,23 @@ impl ServiceLayer for CookieLayer { // The request had a client cookie (and possibly an invalid server // cookie). Generate a new server cookie and include it. Some(CookieMetadata::ClientCookie(cookie)) => { + trace!(target: "CookieLayer", + "Generating cookie for response {}", + response.response.id); self.generate_cookie(addr, *cookie) } // The request had a server cookie that may need to be renewed. Some(CookieMetadata::ServerCookie { cookie, regenerate }) => { if *regenerate { + trace!(target: "CookieLayer", + "Refreshing cookie for response {}", + response.response.id); self.generate_cookie(addr, *cookie.request()) } else { + trace!(target: "CookieLayer", + "Using existing server cookie for response {}", + response.response.id); cookie.clone() } } @@ -367,6 +391,7 @@ impl CookieSecrets { //----------- CookieMetadata ------------------------------------------------- /// Information about a DNS request's use of cookies. +#[derive(Clone, Debug)] pub enum CookieMetadata { /// The request did not use DNS cookies. None, diff --git a/src/new_server/transport/mod.rs b/src/new_server/transport/mod.rs index 8cb7012d0..5aae60dc4 100644 --- a/src/new_server/transport/mod.rs +++ b/src/new_server/transport/mod.rs @@ -4,6 +4,7 @@ use core::net::{IpAddr, SocketAddr}; use std::{io, sync::Arc, time::SystemTime, vec::Vec}; use bumpalo::Bump; +use log::trace; use tokio::net::UdpSocket; use crate::{ @@ -70,17 +71,23 @@ pub async fn serve_udp( metadata: vec![Metadata::new(SourceIpAddr(peer.ip()))], }; + trace!(target: "serve_udp", + "Received request {} from {peer}", + exchange.request.id); + // Generate the appropriate response. self.service.respond(&mut exchange).await; + trace!(target: "serve_udp", + "Sending response {} to {peer}", + exchange.response.id); + // Build up the response message. - println!("Returning response: {:?}", exchange.response); let mut buffer = vec![0u8; 65536]; let message = exchange.response.build(&mut buffer).unwrap_or_else(|_| { todo!("how to handle truncation errors?") }); - println!("In bytes: {:?}", message.as_bytes()); // Send the response back to the peer. let _ = self.socket.send_to(message.as_bytes(), peer).await; From 1c55f0441d60686e82875cc50ae51111363c2567 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 7 Feb 2025 10:29:39 +0100 Subject: [PATCH 137/167] Add example program 'new-server' --- Cargo.lock | 107 +++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 5 ++ examples/new-server.rs | 92 +++++++++++++++++++++++++++++++++++ 3 files changed, 204 insertions(+) create mode 100644 examples/new-server.rs diff --git a/Cargo.lock b/Cargo.lock index ce92a567a..b18b3a819 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,56 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" +dependencies = [ + "anstyle", + "once_cell", + "windows-sys 0.59.0", +] + [[package]] name = "arbitrary" version = "1.4.1" @@ -167,6 +217,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -236,6 +292,7 @@ dependencies = [ "bytes", "chrono", "domain-macros", + "env_logger", "futures-util", "hashbrown 0.14.5", "heapless", @@ -288,6 +345,29 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "env_filter" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -495,6 +575,12 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "iana-time-zone" version = "0.1.61" @@ -528,6 +614,12 @@ dependencies = [ "hashbrown 0.15.2", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itertools" version = "0.13.0" @@ -1441,6 +1533,12 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.12.1" @@ -1648,6 +1746,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.48.5" diff --git a/Cargo.toml b/Cargo.toml index c934b8215..da0deeebb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -100,6 +100,7 @@ tokio-rustls = { version = "0.26", default-features = false, features = [ tokio-test = "0.4" tokio-tfo = { version = "0.2.0" } webpki-roots = { version = "0.26" } +env_logger = { version = "0.11" } # For the "mysql-zone" example #sqlx = { version = "0.6", features = [ "runtime-tokio-native-tls", "mysql" ] } @@ -119,6 +120,10 @@ required-features = ["resolv"] name = "lookup" required-features = ["resolv"] +[[example]] +name = "new-server" +required-features = ["net", "unstable-server-transport", "unstable-client-transport"] + [[example]] name = "resolv-sync" required-features = ["resolv-sync"] diff --git a/examples/new-server.rs b/examples/new-server.rs new file mode 100644 index 000000000..27b4b643a --- /dev/null +++ b/examples/new-server.rs @@ -0,0 +1,92 @@ +use std::ops::ControlFlow; + +use log::trace; + +use domain::new_server::{ + exchange::{OutgoingResponse, ResponseCode}, + layers::{ + cookie::{CookieMetadata, CookiePolicy, CookieSecrets}, + CookieLayer, + }, + transport, Exchange, LocalService, LocalServiceLayer, Service, + ServiceLayer, +}; + +pub struct MyService; + +impl Service for MyService { + async fn respond(&self, exchange: &mut Exchange<'_>) { + let cookie = exchange + .metadata + .iter() + .find_map(|m| m.try_as::()); + + if let Some(CookieMetadata::ServerCookie { .. }) = cookie { + trace!(target: "MyService", "Request had a valid cookie"); + } else { + trace!(target: "MyService", "Request did not have a valid cookie"); + } + + exchange.respond(ResponseCode::Success); + + // Copy all questions from the request to the response. + exchange + .response + .questions + .append(&mut exchange.request.questions); + } +} + +impl LocalService for MyService { + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + self.respond(exchange).await + } +} + +pub struct MyLayer; + +impl ServiceLayer for MyLayer { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + trace!(target: "MyLayer", + "Incoming request (message ID {})", + exchange.request.id); + ControlFlow::Continue(()) + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + trace!(target: "MyLayer", + "Outgoing response (message ID {})", + response.response.id); + } +} + +impl LocalServiceLayer for MyLayer { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.process_incoming(exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + self.process_outgoing(response).await + } +} + +#[tokio::main] +async fn main() { + env_logger::init(); + + let addr = "127.0.0.1:8080".parse().unwrap(); + let cookie_layer = + CookieLayer::new(CookiePolicy::default(), CookieSecrets::generate()); + let service = (MyLayer, cookie_layer, MyService); + let result = transport::serve_udp(addr, service).await; + println!("Ended on result {result:?}"); +} From e833f81ff4606bb31389ffd952a0f90945fca10f Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 14 Feb 2025 14:49:35 +0100 Subject: [PATCH 138/167] [new_server] Implement 'MinAnyLayer' --- src/new_base/charstr.rs | 8 +++ src/new_base/question.rs | 4 ++ src/new_server/layers/min_any.rs | 85 ++++++++++++++++++++++++++++++++ src/new_server/layers/mod.rs | 3 ++ 4 files changed, 100 insertions(+) create mode 100644 src/new_server/layers/min_any.rs diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index adf8e6024..e3c704930 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -17,6 +17,14 @@ pub struct CharStr { pub octets: [u8], } +//--- Associated constants + +impl CharStr { + /// A zero-length [`CharStr`]. + pub const EMPTY: &'static Self = + unsafe { core::mem::transmute(&[0u8] as &[u8]) }; +} + //--- Parsing from DNS messages impl<'a> SplitMessageBytes<'a> for &'a CharStr { diff --git a/src/new_base/question.rs b/src/new_base/question.rs index 82ea76f25..e91873c38 100644 --- a/src/new_base/question.rs +++ b/src/new_base/question.rs @@ -167,6 +167,9 @@ impl QType { /// The type of an [`Aaaa`](crate::new_rdata::Aaaa) record. pub const AAAA: Self = Self::new(28); + + /// All possible records. + pub const ANY: Self = Self::new(255); } //--- Formatting @@ -184,6 +187,7 @@ impl fmt::Debug for QType { Self::MX => "QType::MX", Self::TXT => "QType::TXT", Self::AAAA => "QType::AAAA", + Self::ANY => "QType::ANY", _ => return write!(f, "QType({})", self.code), }) } diff --git a/src/new_server/layers/min_any.rs b/src/new_server/layers/min_any.rs new file mode 100644 index 000000000..9a787a99d --- /dev/null +++ b/src/new_server/layers/min_any.rs @@ -0,0 +1,85 @@ +//! Providing minimal responses to ANY queries. +//! +//! See [RFC 8482](https://datatracker.ietf.org/doc/html/rfc8482) for more. + +use core::ops::ControlFlow; + +use crate::{ + new_base::{ + wire::ParseBytes, CharStr, QClass, QType, Question, RClass, RType, + Record, TTL, + }, + new_rdata::{HInfo, RecordData}, + new_server::{ + exchange::{OutgoingResponse, ResponseCode}, + Exchange, LocalServiceLayer, ServiceLayer, + }, +}; + +/// A simple responder for ANY queries. +/// +/// Conventionally, queries with `QTYPE=ANY` would result in large responses +/// containing all the records for a particular name. While they have some +/// legitimate use cases, they are quite rare, and they can be abused towards +/// denial-of-service attacks. +/// +/// In the spirit of [RFC 8482], this service layer responds to `QTYPE=ANY` +/// queries (regardless of the queried name or class) with a short hardcoded +/// response (specifically, a fake `HINFO` record with the `CPU` string set to +/// `RFC8482`, and the `OS` string empty). +/// +/// [RFC 8482]: https://datatracker.ietf.org/doc/html/rfc8482 +pub struct MinAnyLayer; + +impl ServiceLayer for MinAnyLayer { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + match exchange.request.questions.as_slice() { + [Question { + qname, + qtype: QType::ANY, + qclass: QClass::IN, + }] => { + let record = Record { + rname: *qname, + rtype: RType::HINFO, + rclass: RClass::IN, + ttl: TTL::from(3600), + rdata: RecordData::HInfo(HInfo { + cpu: <&CharStr>::parse_bytes(b"\x07RFC8482").unwrap(), + os: CharStr::EMPTY, + }), + }; + exchange.respond(ResponseCode::Success); + exchange.response.answers.push(record); + ControlFlow::Break(()) + } + + _ => ControlFlow::Continue(()), + } + } + + async fn process_outgoing(&self, _response: OutgoingResponse<'_, '_>) { + // A later layer caught the request and built a response to it. That + // means that the request wasn't a QTYPE=ANY query, so we don't have + // to do anything here. + } +} + +impl LocalServiceLayer for MinAnyLayer { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.process_incoming(exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + self.process_outgoing(response).await + } +} diff --git a/src/new_server/layers/mod.rs b/src/new_server/layers/mod.rs index 885e64001..7d0d9bfb9 100644 --- a/src/new_server/layers/mod.rs +++ b/src/new_server/layers/mod.rs @@ -2,3 +2,6 @@ pub mod cookie; pub use cookie::CookieLayer; + +mod min_any; +pub use min_any::MinAnyLayer; From c6fc39f1491697bbc501a23321f07389ebe5fa46 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 28 Feb 2025 12:46:21 +0100 Subject: [PATCH 139/167] [new_rdata/basic] Refactor into submodules --- src/new_rdata/basic.rs | 518 ----------------------------------- src/new_rdata/basic/a.rs | 77 ++++++ src/new_rdata/basic/cname.rs | 48 ++++ src/new_rdata/basic/hinfo.rs | 44 +++ src/new_rdata/basic/mod.rs | 30 ++ src/new_rdata/basic/mx.rs | 62 +++++ src/new_rdata/basic/ns.rs | 48 ++++ src/new_rdata/basic/ptr.rs | 48 ++++ src/new_rdata/basic/soa.rs | 90 ++++++ src/new_rdata/basic/txt.rs | 97 +++++++ src/new_rdata/basic/wks.rs | 62 +++++ 11 files changed, 606 insertions(+), 518 deletions(-) delete mode 100644 src/new_rdata/basic.rs create mode 100644 src/new_rdata/basic/a.rs create mode 100644 src/new_rdata/basic/cname.rs create mode 100644 src/new_rdata/basic/hinfo.rs create mode 100644 src/new_rdata/basic/mod.rs create mode 100644 src/new_rdata/basic/mx.rs create mode 100644 src/new_rdata/basic/ns.rs create mode 100644 src/new_rdata/basic/ptr.rs create mode 100644 src/new_rdata/basic/soa.rs create mode 100644 src/new_rdata/basic/txt.rs create mode 100644 src/new_rdata/basic/wks.rs diff --git a/src/new_rdata/basic.rs b/src/new_rdata/basic.rs deleted file mode 100644 index e880721da..000000000 --- a/src/new_rdata/basic.rs +++ /dev/null @@ -1,518 +0,0 @@ -//! Core record data types. -//! -//! See [RFC 1035](https://datatracker.ietf.org/doc/html/rfc1035). - -use core::fmt; -use core::net::Ipv4Addr; -use core::str::FromStr; - -use domain_macros::*; - -use crate::new_base::{ - build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseMessageBytes, SplitMessageBytes}, - wire::{AsBytes, ParseBytes, ParseError, SplitBytes, U16, U32}, - CharStr, Serial, -}; - -//----------- A -------------------------------------------------------------- - -/// The IPv4 address of a host responsible for this domain. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - AsBytes, - BuildBytes, - ParseBytes, - ParseBytesByRef, - SplitBytes, - SplitBytesByRef, -)] -#[repr(transparent)] -pub struct A { - /// The IPv4 address octets. - pub octets: [u8; 4], -} - -//--- Converting to and from 'Ipv4Addr' - -impl From for A { - fn from(value: Ipv4Addr) -> Self { - Self { - octets: value.octets(), - } - } -} - -impl From for Ipv4Addr { - fn from(value: A) -> Self { - Self::from(value.octets) - } -} - -//--- Parsing from a string - -impl FromStr for A { - type Err = ::Err; - - fn from_str(s: &str) -> Result { - Ipv4Addr::from_str(s).map(A::from) - } -} - -//--- Formatting - -impl fmt::Display for A { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - Ipv4Addr::from(*self).fmt(f) - } -} - -//--- Building into DNS messages - -impl BuildIntoMessage for A { - fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { - self.as_bytes().build_into_message(builder) - } -} - -//----------- Ns ------------------------------------------------------------- - -/// The authoritative name server for this domain. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - BuildBytes, - ParseBytes, - SplitBytes, -)] -#[repr(transparent)] -pub struct Ns { - /// The name of the authoritative server. - pub name: N, -} - -//--- Parsing from DNS messages - -impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ns { - fn parse_message_bytes( - contents: &'a [u8], - start: usize, - ) -> Result { - N::parse_message_bytes(contents, start).map(|name| Self { name }) - } -} - -//--- Building into DNS messages - -impl BuildIntoMessage for Ns { - fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { - self.name.build_into_message(builder) - } -} - -//----------- Cname ---------------------------------------------------------- - -/// The canonical name for this domain. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - BuildBytes, - ParseBytes, - SplitBytes, -)] -#[repr(transparent)] -pub struct CName { - /// The canonical name. - pub name: N, -} - -//--- Parsing from DNS messages - -impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for CName { - fn parse_message_bytes( - contents: &'a [u8], - start: usize, - ) -> Result { - N::parse_message_bytes(contents, start).map(|name| Self { name }) - } -} - -//--- Building into DNS messages - -impl BuildIntoMessage for CName { - fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { - self.name.build_into_message(builder) - } -} - -//----------- Soa ------------------------------------------------------------ - -/// The start of a zone of authority. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - Hash, - BuildBytes, - ParseBytes, - SplitBytes, -)] -pub struct Soa { - /// The name server which provided this zone. - pub mname: N, - - /// The mailbox of the maintainer of this zone. - pub rname: N, - - /// The version number of the original copy of this zone. - pub serial: Serial, - - /// The number of seconds to wait until refreshing the zone. - pub refresh: U32, - - /// The number of seconds to wait until retrying a failed refresh. - pub retry: U32, - - /// The number of seconds until the zone is considered expired. - pub expire: U32, - - /// The minimum TTL for any record in this zone. - pub minimum: U32, -} - -//--- Parsing from DNS messages - -impl<'a, N: SplitMessageBytes<'a>> ParseMessageBytes<'a> for Soa { - fn parse_message_bytes( - contents: &'a [u8], - start: usize, - ) -> Result { - let (mname, rest) = N::split_message_bytes(contents, start)?; - let (rname, rest) = N::split_message_bytes(contents, rest)?; - let (&serial, rest) = <&Serial>::split_message_bytes(contents, rest)?; - let (&refresh, rest) = <&U32>::split_message_bytes(contents, rest)?; - let (&retry, rest) = <&U32>::split_message_bytes(contents, rest)?; - let (&expire, rest) = <&U32>::split_message_bytes(contents, rest)?; - let &minimum = <&U32>::parse_message_bytes(contents, rest)?; - - Ok(Self { - mname, - rname, - serial, - refresh, - retry, - expire, - minimum, - }) - } -} - -//--- Building into DNS messages - -impl BuildIntoMessage for Soa { - fn build_into_message( - &self, - mut builder: build::Builder<'_>, - ) -> BuildResult { - self.mname.build_into_message(builder.delegate())?; - self.rname.build_into_message(builder.delegate())?; - builder.append_bytes(self.serial.as_bytes())?; - builder.append_bytes(self.refresh.as_bytes())?; - builder.append_bytes(self.retry.as_bytes())?; - builder.append_bytes(self.expire.as_bytes())?; - builder.append_bytes(self.minimum.as_bytes())?; - Ok(builder.commit()) - } -} - -//----------- Wks ------------------------------------------------------------ - -/// Well-known services supported on this domain. -#[derive(AsBytes, BuildBytes, ParseBytesByRef)] -#[repr(C, packed)] -pub struct Wks { - /// The address of the host providing these services. - pub address: A, - - /// The IP protocol number for the services (e.g. TCP). - pub protocol: u8, - - /// A bitset of supported well-known ports. - pub ports: [u8], -} - -//--- Formatting - -impl fmt::Debug for Wks { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - struct Ports<'a>(&'a [u8]); - - impl fmt::Debug for Ports<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let entries = self - .0 - .iter() - .enumerate() - .flat_map(|(i, &b)| (0..8).map(move |j| (i, j, b))) - .filter(|(_, j, b)| b & (1 << j) != 0) - .map(|(i, j, _)| i * 8 + j); - - f.debug_set().entries(entries).finish() - } - } - - f.debug_struct("Wks") - .field("address", &format_args!("{}", self.address)) - .field("protocol", &self.protocol) - .field("ports", &Ports(&self.ports)) - .finish() - } -} - -//--- Building into DNS messages - -impl BuildIntoMessage for Wks { - fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { - self.as_bytes().build_into_message(builder) - } -} - -//----------- Ptr ------------------------------------------------------------ - -/// A pointer to another domain name. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - BuildBytes, - ParseBytes, - SplitBytes, -)] -#[repr(transparent)] -pub struct Ptr { - /// The referenced domain name. - pub name: N, -} - -//--- Parsing from DNS messages - -impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ptr { - fn parse_message_bytes( - contents: &'a [u8], - start: usize, - ) -> Result { - N::parse_message_bytes(contents, start).map(|name| Self { name }) - } -} - -//--- Building into DNS messages - -impl BuildIntoMessage for Ptr { - fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { - self.name.build_into_message(builder) - } -} - -//----------- HInfo ---------------------------------------------------------- - -/// Information about the host computer. -#[derive(Clone, Debug, PartialEq, Eq, BuildBytes, ParseBytes, SplitBytes)] -pub struct HInfo<'a> { - /// The CPU type. - pub cpu: &'a CharStr, - - /// The OS type. - pub os: &'a CharStr, -} - -//--- Parsing from DNS messages - -impl<'a> ParseMessageBytes<'a> for HInfo<'a> { - fn parse_message_bytes( - contents: &'a [u8], - start: usize, - ) -> Result { - Self::parse_bytes(&contents[start..]) - } -} - -//--- Building into DNS messages - -impl BuildIntoMessage for HInfo<'_> { - fn build_into_message( - &self, - mut builder: build::Builder<'_>, - ) -> BuildResult { - self.cpu.build_into_message(builder.delegate())?; - self.os.build_into_message(builder.delegate())?; - Ok(builder.commit()) - } -} - -//----------- Mx ------------------------------------------------------------- - -/// A host that can exchange mail for this domain. -#[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - BuildBytes, - ParseBytes, - SplitBytes, -)] -#[repr(C)] -pub struct Mx { - /// The preference for this host over others. - pub preference: U16, - - /// The domain name of the mail exchanger. - pub exchange: N, -} - -//--- Parsing from DNS messages - -impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Mx { - fn parse_message_bytes( - contents: &'a [u8], - start: usize, - ) -> Result { - let (&preference, rest) = - <&U16>::split_message_bytes(contents, start)?; - let exchange = N::parse_message_bytes(contents, rest)?; - Ok(Self { - preference, - exchange, - }) - } -} - -//--- Building into DNS messages - -impl BuildIntoMessage for Mx { - fn build_into_message( - &self, - mut builder: build::Builder<'_>, - ) -> BuildResult { - builder.append_bytes(self.preference.as_bytes())?; - self.exchange.build_into_message(builder.delegate())?; - Ok(builder.commit()) - } -} - -//----------- Txt ------------------------------------------------------------ - -/// Free-form text strings about this domain. -#[derive(AsBytes, BuildBytes)] -#[repr(transparent)] -pub struct Txt { - /// The text strings, as concatenated [`CharStr`]s. - /// - /// The [`CharStr`]s begin with a length octet so they can be separated. - content: [u8], -} - -//--- Interaction - -impl Txt { - /// Iterate over the [`CharStr`]s in this record. - pub fn iter( - &self, - ) -> impl Iterator> + '_ { - // NOTE: A TXT record always has at least one 'CharStr' within. - let first = <&CharStr>::split_bytes(&self.content); - core::iter::successors(Some(first), |prev| { - prev.as_ref() - .ok() - .map(|(_elem, rest)| <&CharStr>::split_bytes(rest)) - }) - .map(|result| result.map(|(elem, _rest)| elem)) - } -} - -//--- Parsing from DNS messages - -impl<'a> ParseMessageBytes<'a> for &'a Txt { - fn parse_message_bytes( - contents: &'a [u8], - start: usize, - ) -> Result { - Self::parse_bytes(&contents[start..]) - } -} - -//--- Building into DNS messages - -impl BuildIntoMessage for Txt { - fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { - self.content.build_into_message(builder) - } -} - -//--- Parsing from bytes - -impl<'a> ParseBytes<'a> for &'a Txt { - fn parse_bytes(bytes: &'a [u8]) -> Result { - // NOTE: The input must contain at least one 'CharStr'. - let (_, mut rest) = <&CharStr>::split_bytes(bytes)?; - while !rest.is_empty() { - (_, rest) = <&CharStr>::split_bytes(rest)?; - } - - // SAFETY: 'Txt' is 'repr(transparent)' to '[u8]'. - Ok(unsafe { core::mem::transmute::<&'a [u8], Self>(bytes) }) - } -} - -//--- Formatting - -impl fmt::Debug for Txt { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - struct Content<'a>(&'a Txt); - impl fmt::Debug for Content<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut list = f.debug_list(); - for elem in self.0.iter() { - if let Ok(elem) = elem { - list.entry(&elem); - } else { - list.entry(&ParseError); - } - } - list.finish() - } - } - - f.debug_tuple("Txt").field(&Content(self)).finish() - } -} diff --git a/src/new_rdata/basic/a.rs b/src/new_rdata/basic/a.rs new file mode 100644 index 000000000..3f8d8b2aa --- /dev/null +++ b/src/new_rdata/basic/a.rs @@ -0,0 +1,77 @@ +use core::fmt; +use core::net::Ipv4Addr; +use core::str::FromStr; + +use domain_macros::*; + +use crate::new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + wire::AsBytes, +}; + +//----------- A -------------------------------------------------------------- + +/// The IPv4 address of a host responsible for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(transparent)] +pub struct A { + /// The IPv4 address octets. + pub octets: [u8; 4], +} + +//--- Converting to and from 'Ipv4Addr' + +impl From for A { + fn from(value: Ipv4Addr) -> Self { + Self { + octets: value.octets(), + } + } +} + +impl From for Ipv4Addr { + fn from(value: A) -> Self { + Self::from(value.octets) + } +} + +//--- Parsing from a string + +impl FromStr for A { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + Ipv4Addr::from_str(s).map(A::from) + } +} + +//--- Formatting + +impl fmt::Display for A { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Ipv4Addr::from(*self).fmt(f) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for A { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { + self.as_bytes().build_into_message(builder) + } +} diff --git a/src/new_rdata/basic/cname.rs b/src/new_rdata/basic/cname.rs new file mode 100644 index 000000000..c048a1019 --- /dev/null +++ b/src/new_rdata/basic/cname.rs @@ -0,0 +1,48 @@ +use domain_macros::*; + +use crate::new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + parse::ParseMessageBytes, + wire::ParseError, +}; + +//----------- CName ---------------------------------------------------------- + +/// The canonical name for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + BuildBytes, + ParseBytes, + SplitBytes, +)] +#[repr(transparent)] +pub struct CName { + /// The canonical name. + pub name: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for CName { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + N::parse_message_bytes(contents, start).map(|name| Self { name }) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for CName { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { + self.name.build_into_message(builder) + } +} diff --git a/src/new_rdata/basic/hinfo.rs b/src/new_rdata/basic/hinfo.rs new file mode 100644 index 000000000..524653f7a --- /dev/null +++ b/src/new_rdata/basic/hinfo.rs @@ -0,0 +1,44 @@ +use domain_macros::*; + +use crate::new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + parse::ParseMessageBytes, + wire::{ParseBytes, ParseError}, + CharStr, +}; + +//----------- HInfo ---------------------------------------------------------- + +/// Information about the host computer. +#[derive(Clone, Debug, PartialEq, Eq, BuildBytes, ParseBytes, SplitBytes)] +pub struct HInfo<'a> { + /// The CPU type. + pub cpu: &'a CharStr, + + /// The OS type. + pub os: &'a CharStr, +} + +//--- Parsing from DNS messages + +impl<'a> ParseMessageBytes<'a> for HInfo<'a> { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + Self::parse_bytes(&contents[start..]) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for HInfo<'_> { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> BuildResult { + self.cpu.build_into_message(builder.delegate())?; + self.os.build_into_message(builder.delegate())?; + Ok(builder.commit()) + } +} diff --git a/src/new_rdata/basic/mod.rs b/src/new_rdata/basic/mod.rs new file mode 100644 index 000000000..466873011 --- /dev/null +++ b/src/new_rdata/basic/mod.rs @@ -0,0 +1,30 @@ +//! Core record data types. +//! +//! See [RFC 1035](https://datatracker.ietf.org/doc/html/rfc1035). + +mod a; +pub use a::A; + +mod ns; +pub use ns::Ns; + +mod cname; +pub use cname::CName; + +mod soa; +pub use soa::Soa; + +mod wks; +pub use wks::Wks; + +mod ptr; +pub use ptr::Ptr; + +mod hinfo; +pub use hinfo::HInfo; + +mod mx; +pub use mx::Mx; + +mod txt; +pub use txt::Txt; diff --git a/src/new_rdata/basic/mx.rs b/src/new_rdata/basic/mx.rs new file mode 100644 index 000000000..af3f932df --- /dev/null +++ b/src/new_rdata/basic/mx.rs @@ -0,0 +1,62 @@ +use domain_macros::*; + +use crate::new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{AsBytes, ParseError, U16}, +}; + +//----------- Mx ------------------------------------------------------------- + +/// A host that can exchange mail for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + BuildBytes, + ParseBytes, + SplitBytes, +)] +#[repr(C)] +pub struct Mx { + /// The preference for this host over others. + pub preference: U16, + + /// The domain name of the mail exchanger. + pub exchange: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Mx { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + let (&preference, rest) = + <&U16>::split_message_bytes(contents, start)?; + let exchange = N::parse_message_bytes(contents, rest)?; + Ok(Self { + preference, + exchange, + }) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Mx { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> BuildResult { + builder.append_bytes(self.preference.as_bytes())?; + self.exchange.build_into_message(builder.delegate())?; + Ok(builder.commit()) + } +} diff --git a/src/new_rdata/basic/ns.rs b/src/new_rdata/basic/ns.rs new file mode 100644 index 000000000..ac0deea68 --- /dev/null +++ b/src/new_rdata/basic/ns.rs @@ -0,0 +1,48 @@ +use domain_macros::*; + +use crate::new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + parse::ParseMessageBytes, + wire::ParseError, +}; + +//----------- Ns ------------------------------------------------------------- + +/// The authoritative name server for this domain. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + BuildBytes, + ParseBytes, + SplitBytes, +)] +#[repr(transparent)] +pub struct Ns { + /// The name of the authoritative server. + pub name: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ns { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + N::parse_message_bytes(contents, start).map(|name| Self { name }) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Ns { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { + self.name.build_into_message(builder) + } +} diff --git a/src/new_rdata/basic/ptr.rs b/src/new_rdata/basic/ptr.rs new file mode 100644 index 000000000..e4eea4a8c --- /dev/null +++ b/src/new_rdata/basic/ptr.rs @@ -0,0 +1,48 @@ +use domain_macros::*; + +use crate::new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + parse::ParseMessageBytes, + wire::ParseError, +}; + +//----------- Ptr ------------------------------------------------------------ + +/// A pointer to another domain name. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + BuildBytes, + ParseBytes, + SplitBytes, +)] +#[repr(transparent)] +pub struct Ptr { + /// The referenced domain name. + pub name: N, +} + +//--- Parsing from DNS messages + +impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ptr { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + N::parse_message_bytes(contents, start).map(|name| Self { name }) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Ptr { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { + self.name.build_into_message(builder) + } +} diff --git a/src/new_rdata/basic/soa.rs b/src/new_rdata/basic/soa.rs new file mode 100644 index 000000000..f9d236bf4 --- /dev/null +++ b/src/new_rdata/basic/soa.rs @@ -0,0 +1,90 @@ +use domain_macros::*; + +use crate::new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{AsBytes, ParseError, U32}, + Serial, +}; + +//----------- Soa ------------------------------------------------------------ + +/// The start of a zone of authority. +#[derive( + Copy, + Clone, + Debug, + PartialEq, + Eq, + Hash, + BuildBytes, + ParseBytes, + SplitBytes, +)] +pub struct Soa { + /// The name server which provided this zone. + pub mname: N, + + /// The mailbox of the maintainer of this zone. + pub rname: N, + + /// The version number of the original copy of this zone. + pub serial: Serial, + + /// The number of seconds to wait until refreshing the zone. + pub refresh: U32, + + /// The number of seconds to wait until retrying a failed refresh. + pub retry: U32, + + /// The number of seconds until the zone is considered expired. + pub expire: U32, + + /// The minimum TTL for any record in this zone. + pub minimum: U32, +} + +//--- Parsing from DNS messages + +impl<'a, N: SplitMessageBytes<'a>> ParseMessageBytes<'a> for Soa { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + let (mname, rest) = N::split_message_bytes(contents, start)?; + let (rname, rest) = N::split_message_bytes(contents, rest)?; + let (&serial, rest) = <&Serial>::split_message_bytes(contents, rest)?; + let (&refresh, rest) = <&U32>::split_message_bytes(contents, rest)?; + let (&retry, rest) = <&U32>::split_message_bytes(contents, rest)?; + let (&expire, rest) = <&U32>::split_message_bytes(contents, rest)?; + let &minimum = <&U32>::parse_message_bytes(contents, rest)?; + + Ok(Self { + mname, + rname, + serial, + refresh, + retry, + expire, + minimum, + }) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Soa { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> BuildResult { + self.mname.build_into_message(builder.delegate())?; + self.rname.build_into_message(builder.delegate())?; + builder.append_bytes(self.serial.as_bytes())?; + builder.append_bytes(self.refresh.as_bytes())?; + builder.append_bytes(self.retry.as_bytes())?; + builder.append_bytes(self.expire.as_bytes())?; + builder.append_bytes(self.minimum.as_bytes())?; + Ok(builder.commit()) + } +} diff --git a/src/new_rdata/basic/txt.rs b/src/new_rdata/basic/txt.rs new file mode 100644 index 000000000..f659a131e --- /dev/null +++ b/src/new_rdata/basic/txt.rs @@ -0,0 +1,97 @@ +use core::fmt; + +use domain_macros::*; + +use crate::new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + parse::ParseMessageBytes, + wire::{ParseBytes, ParseError, SplitBytes}, + CharStr, +}; + +//----------- Txt ------------------------------------------------------------ + +/// Free-form text strings about this domain. +#[derive(AsBytes, BuildBytes)] +#[repr(transparent)] +pub struct Txt { + /// The text strings, as concatenated [`CharStr`]s. + /// + /// The [`CharStr`]s begin with a length octet so they can be separated. + content: [u8], +} + +//--- Interaction + +impl Txt { + /// Iterate over the [`CharStr`]s in this record. + pub fn iter( + &self, + ) -> impl Iterator> + '_ { + // NOTE: A TXT record always has at least one 'CharStr' within. + let first = <&CharStr>::split_bytes(&self.content); + core::iter::successors(Some(first), |prev| { + prev.as_ref() + .ok() + .map(|(_elem, rest)| <&CharStr>::split_bytes(rest)) + }) + .map(|result| result.map(|(elem, _rest)| elem)) + } +} + +//--- Parsing from DNS messages + +impl<'a> ParseMessageBytes<'a> for &'a Txt { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + Self::parse_bytes(&contents[start..]) + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Txt { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { + self.content.build_into_message(builder) + } +} + +//--- Parsing from bytes + +impl<'a> ParseBytes<'a> for &'a Txt { + fn parse_bytes(bytes: &'a [u8]) -> Result { + // NOTE: The input must contain at least one 'CharStr'. + let (_, mut rest) = <&CharStr>::split_bytes(bytes)?; + while !rest.is_empty() { + (_, rest) = <&CharStr>::split_bytes(rest)?; + } + + // SAFETY: 'Txt' is 'repr(transparent)' to '[u8]'. + Ok(unsafe { core::mem::transmute::<&'a [u8], Self>(bytes) }) + } +} + +//--- Formatting + +impl fmt::Debug for Txt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Content<'a>(&'a Txt); + impl fmt::Debug for Content<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut list = f.debug_list(); + for elem in self.0.iter() { + if let Ok(elem) = elem { + list.entry(&elem); + } else { + list.entry(&ParseError); + } + } + list.finish() + } + } + + f.debug_tuple("Txt").field(&Content(self)).finish() + } +} diff --git a/src/new_rdata/basic/wks.rs b/src/new_rdata/basic/wks.rs new file mode 100644 index 000000000..a02f5d82d --- /dev/null +++ b/src/new_rdata/basic/wks.rs @@ -0,0 +1,62 @@ +use core::fmt; + +use domain_macros::*; + +use crate::new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + wire::AsBytes, +}; + +use super::A; + +//----------- Wks ------------------------------------------------------------ + +/// Well-known services supported on this domain. +#[derive(AsBytes, BuildBytes, ParseBytesByRef)] +#[repr(C, packed)] +pub struct Wks { + /// The address of the host providing these services. + pub address: A, + + /// The IP protocol number for the services (e.g. TCP). + pub protocol: u8, + + /// A bitset of supported well-known ports. + pub ports: [u8], +} + +//--- Formatting + +impl fmt::Debug for Wks { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Ports<'a>(&'a [u8]); + + impl fmt::Debug for Ports<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let entries = self + .0 + .iter() + .enumerate() + .flat_map(|(i, &b)| (0..8).map(move |j| (i, j, b))) + .filter(|(_, j, b)| b & (1 << j) != 0) + .map(|(i, j, _)| i * 8 + j); + + f.debug_set().entries(entries).finish() + } + } + + f.debug_struct("Wks") + .field("address", &format_args!("{}", self.address)) + .field("protocol", &self.protocol) + .field("ports", &Ports(&self.ports)) + .finish() + } +} + +//--- Building into DNS messages + +impl BuildIntoMessage for Wks { + fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { + self.as_bytes().build_into_message(builder) + } +} From d0045deec0fccfaeb2d2c567c47e683ff553bfb3 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 3 Mar 2025 12:29:41 +0100 Subject: [PATCH 140/167] [new_rdata] Impl equality for 'RecordData' --- src/new_rdata/basic/txt.rs | 38 ++++++++++++++++++++------------------ src/new_rdata/basic/wks.rs | 25 +++++++++++++++++++++++++ src/new_rdata/mod.rs | 4 ++-- 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/src/new_rdata/basic/txt.rs b/src/new_rdata/basic/txt.rs index f659a131e..a7188e9b0 100644 --- a/src/new_rdata/basic/txt.rs +++ b/src/new_rdata/basic/txt.rs @@ -25,17 +25,17 @@ pub struct Txt { impl Txt { /// Iterate over the [`CharStr`]s in this record. - pub fn iter( - &self, - ) -> impl Iterator> + '_ { + pub fn iter(&self) -> impl Iterator + '_ { // NOTE: A TXT record always has at least one 'CharStr' within. - let first = <&CharStr>::split_bytes(&self.content); - core::iter::successors(Some(first), |prev| { - prev.as_ref() - .ok() - .map(|(_elem, rest)| <&CharStr>::split_bytes(rest)) + let first = <&CharStr>::split_bytes(&self.content) + .expect("'Txt' records always contain valid 'CharStr's"); + core::iter::successors(Some(first), |(_, rest)| { + (!rest.is_empty()).then(|| { + <&CharStr>::split_bytes(rest) + .expect("'Txt' records always contain valid 'CharStr's") + }) }) - .map(|result| result.map(|(elem, _rest)| elem)) + .map(|(elem, _rest)| elem) } } @@ -80,18 +80,20 @@ impl fmt::Debug for Txt { struct Content<'a>(&'a Txt); impl fmt::Debug for Content<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut list = f.debug_list(); - for elem in self.0.iter() { - if let Ok(elem) = elem { - list.entry(&elem); - } else { - list.entry(&ParseError); - } - } - list.finish() + f.debug_list().entries(self.0.iter()).finish() } } f.debug_tuple("Txt").field(&Content(self)).finish() } } + +//--- Equality + +impl PartialEq for Txt { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl Eq for Txt {} diff --git a/src/new_rdata/basic/wks.rs b/src/new_rdata/basic/wks.rs index a02f5d82d..bacadeefa 100644 --- a/src/new_rdata/basic/wks.rs +++ b/src/new_rdata/basic/wks.rs @@ -60,3 +60,28 @@ impl BuildIntoMessage for Wks { self.as_bytes().build_into_message(builder) } } + +//--- Equality + +impl PartialEq for Wks { + fn eq(&self, other: &Self) -> bool { + if self.address != other.address || self.protocol != other.protocol { + return false; + } + + // Iterate through the ports, ignoring trailing zero bytes. + let mut lp = self.ports.iter(); + let mut rp = other.ports.iter(); + while lp.len() > 0 || rp.len() > 0 { + match (lp.next(), rp.next()) { + (Some(l), Some(r)) if l != r => return false, + (Some(l), None) if *l != 0 => return false, + (None, Some(r)) if *r != 0 => return false, + _ => {} + } + } + true + } +} + +impl Eq for Wks {} diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 70f041240..bd5c8697b 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -23,7 +23,7 @@ pub use edns::{EdnsOptionsIter, Opt}; //----------- RecordData ----------------------------------------------------- /// DNS record data. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] #[non_exhaustive] pub enum RecordData<'a, N> { /// The IPv4 address of a host responsible for this domain. @@ -181,7 +181,7 @@ impl BuildBytes for RecordData<'_, N> { //----------- UnknownRecordData ---------------------------------------------- /// Data for an unknown DNS record type. -#[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] +#[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef, PartialEq, Eq)] #[repr(transparent)] pub struct UnknownRecordData { /// The unparsed option data. From a8553938c5cedcb78459dc08859a20bd1235b29e Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 3 Mar 2025 13:48:09 +0100 Subject: [PATCH 141/167] [new_rdata] Add 'map_name(s)_by_ref()' --- src/new_rdata/basic/cname.rs | 21 ++++++++++++ src/new_rdata/basic/mx.rs | 23 +++++++++++++ src/new_rdata/basic/ns.rs | 21 ++++++++++++ src/new_rdata/basic/ptr.rs | 21 ++++++++++++ src/new_rdata/basic/soa.rs | 33 ++++++++++++++++++ src/new_rdata/mod.rs | 65 ++++++++++++++++++++++++++++++++++++ 6 files changed, 184 insertions(+) diff --git a/src/new_rdata/basic/cname.rs b/src/new_rdata/basic/cname.rs index c048a1019..24c794e25 100644 --- a/src/new_rdata/basic/cname.rs +++ b/src/new_rdata/basic/cname.rs @@ -28,6 +28,27 @@ pub struct CName { pub name: N, } +//--- Interaction + +impl CName { + /// Map the domain name within to another type. + pub fn map_name R>(self, f: F) -> CName { + CName { + name: (f)(self.name), + } + } + + /// Map a reference to the domain name within to another type. + pub fn map_name_by_ref<'r, R, F: FnOnce(&'r N) -> R>( + &'r self, + f: F, + ) -> CName { + CName { + name: (f)(&self.name), + } + } +} + //--- Parsing from DNS messages impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for CName { diff --git a/src/new_rdata/basic/mx.rs b/src/new_rdata/basic/mx.rs index af3f932df..bf420e517 100644 --- a/src/new_rdata/basic/mx.rs +++ b/src/new_rdata/basic/mx.rs @@ -31,6 +31,29 @@ pub struct Mx { pub exchange: N, } +//--- Interaction + +impl Mx { + /// Map the domain name within to another type. + pub fn map_name R>(self, f: F) -> Mx { + Mx { + preference: self.preference, + exchange: (f)(self.exchange), + } + } + + /// Map a reference to the domain name within to another type. + pub fn map_name_by_ref<'r, R, F: FnOnce(&'r N) -> R>( + &'r self, + f: F, + ) -> Mx { + Mx { + preference: self.preference, + exchange: (f)(&self.exchange), + } + } +} + //--- Parsing from DNS messages impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Mx { diff --git a/src/new_rdata/basic/ns.rs b/src/new_rdata/basic/ns.rs index ac0deea68..e90f1bdb2 100644 --- a/src/new_rdata/basic/ns.rs +++ b/src/new_rdata/basic/ns.rs @@ -28,6 +28,27 @@ pub struct Ns { pub name: N, } +//--- Interaction + +impl Ns { + /// Map the domain name within to another type. + pub fn map_name R>(self, f: F) -> Ns { + Ns { + name: (f)(self.name), + } + } + + /// Map a reference to the domain name within to another type. + pub fn map_name_by_ref<'r, R, F: FnOnce(&'r N) -> R>( + &'r self, + f: F, + ) -> Ns { + Ns { + name: (f)(&self.name), + } + } +} + //--- Parsing from DNS messages impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ns { diff --git a/src/new_rdata/basic/ptr.rs b/src/new_rdata/basic/ptr.rs index e4eea4a8c..7cd530898 100644 --- a/src/new_rdata/basic/ptr.rs +++ b/src/new_rdata/basic/ptr.rs @@ -28,6 +28,27 @@ pub struct Ptr { pub name: N, } +//--- Interaction + +impl Ptr { + /// Map the domain name within to another type. + pub fn map_name R>(self, f: F) -> Ptr { + Ptr { + name: (f)(self.name), + } + } + + /// Map a reference to the domain name within to another type. + pub fn map_name_by_ref<'r, R, F: FnOnce(&'r N) -> R>( + &'r self, + f: F, + ) -> Ptr { + Ptr { + name: (f)(&self.name), + } + } +} + //--- Parsing from DNS messages impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ptr { diff --git a/src/new_rdata/basic/soa.rs b/src/new_rdata/basic/soa.rs index f9d236bf4..604648c95 100644 --- a/src/new_rdata/basic/soa.rs +++ b/src/new_rdata/basic/soa.rs @@ -44,6 +44,39 @@ pub struct Soa { pub minimum: U32, } +//--- Interaction + +impl Soa { + /// Map the domain names within to another type. + pub fn map_names R>(self, mut f: F) -> Soa { + Soa { + mname: (f)(self.mname), + rname: (f)(self.rname), + serial: self.serial, + refresh: self.refresh, + retry: self.retry, + expire: self.expire, + minimum: self.minimum, + } + } + + /// Map references to the domain names within to another type. + pub fn map_names_by_ref<'r, R, F: FnMut(&'r N) -> R>( + &'r self, + mut f: F, + ) -> Soa { + Soa { + mname: (f)(&self.mname), + rname: (f)(&self.rname), + serial: self.serial, + refresh: self.refresh, + retry: self.retry, + expire: self.expire, + minimum: self.minimum, + } + } +} + //--- Parsing from DNS messages impl<'a, N: SplitMessageBytes<'a>> ParseMessageBytes<'a> for Soa { diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index bd5c8697b..19524a0bf 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -63,6 +63,71 @@ pub enum RecordData<'a, N> { Unknown(RType, &'a UnknownRecordData), } +//--- Inspection + +impl RecordData<'_, N> { + /// The type of this record data. + pub const fn rtype(&self) -> RType { + match self { + Self::A(..) => RType::A, + Self::Ns(..) => RType::NS, + Self::CName(..) => RType::CNAME, + Self::Soa(..) => RType::SOA, + Self::Wks(..) => RType::WKS, + Self::Ptr(..) => RType::PTR, + Self::HInfo(..) => RType::HINFO, + Self::Mx(..) => RType::MX, + Self::Txt(..) => RType::TXT, + Self::Aaaa(..) => RType::AAAA, + Self::Opt(..) => RType::OPT, + Self::Unknown(rtype, _) => *rtype, + } + } +} + +//--- Interaction + +impl<'a, N> RecordData<'a, N> { + /// Map the domain names within to another type. + pub fn map_names R>(self, f: F) -> RecordData<'a, R> { + match self { + Self::A(r) => RecordData::A(r), + Self::Ns(r) => RecordData::Ns(r.map_name(f)), + Self::CName(r) => RecordData::CName(r.map_name(f)), + Self::Soa(r) => RecordData::Soa(r.map_names(f)), + Self::Wks(r) => RecordData::Wks(r), + Self::Ptr(r) => RecordData::Ptr(r.map_name(f)), + Self::HInfo(r) => RecordData::HInfo(r), + Self::Mx(r) => RecordData::Mx(r.map_name(f)), + Self::Txt(r) => RecordData::Txt(r), + Self::Aaaa(r) => RecordData::Aaaa(r), + Self::Opt(r) => RecordData::Opt(r), + Self::Unknown(rt, rd) => RecordData::Unknown(rt, rd), + } + } + + /// Map references to the domain names within to another type. + pub fn map_names_by_ref<'r, R, F: FnMut(&'r N) -> R>( + &'r self, + f: F, + ) -> RecordData<'r, R> { + match self { + Self::A(r) => RecordData::A(r), + Self::Ns(r) => RecordData::Ns(r.map_name_by_ref(f)), + Self::CName(r) => RecordData::CName(r.map_name_by_ref(f)), + Self::Soa(r) => RecordData::Soa(r.map_names_by_ref(f)), + Self::Wks(r) => RecordData::Wks(r), + Self::Ptr(r) => RecordData::Ptr(r.map_name_by_ref(f)), + Self::HInfo(r) => RecordData::HInfo(r.clone()), + Self::Mx(r) => RecordData::Mx(r.map_name_by_ref(f)), + Self::Txt(r) => RecordData::Txt(r), + Self::Aaaa(r) => RecordData::Aaaa(r), + Self::Opt(r) => RecordData::Opt(r), + Self::Unknown(rt, rd) => RecordData::Unknown(*rt, rd), + } + } +} + //--- Parsing record data impl<'a, N> ParseRecordData<'a> for RecordData<'a, N> From 946ab216bf2841beb6045dd2c3e7a6f07c2f0cb9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 7 Mar 2025 00:00:29 +0100 Subject: [PATCH 142/167] [new_base/wire] Make 'SizePrefixed' generic over size type Some DNSSEC types use single-byte prefixed data, for which I'd still like to make use of 'SizePrefixed' to avoid boilerplate. --- src/new_base/build/record.rs | 4 +- src/new_base/record.rs | 9 +- src/new_base/wire/ints.rs | 14 +++ src/new_base/wire/size_prefixed.rs | 191 +++++++++++++++++------------ src/new_edns/mod.rs | 8 +- 5 files changed, 140 insertions(+), 86 deletions(-) diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index 7d9a48b5d..e43495cac 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -5,7 +5,7 @@ use core::{mem::ManuallyDrop, ptr}; use crate::new_base::{ name::UnparsedName, parse::ParseMessageBytes, - wire::{AsBytes, ParseBytes, SizePrefixed, TruncationError}, + wire::{AsBytes, ParseBytes, SizePrefixed, TruncationError, U16}, RClass, RType, Record, TTL, }; @@ -57,7 +57,7 @@ impl<'b> RecordBuilder<'b> { b.append_bytes(record.rclass.as_bytes())?; b.append_bytes(record.ttl.as_bytes())?; let size = b.context().size; - SizePrefixed::new(&record.rdata) + SizePrefixed::::new(&record.rdata) .build_into_message(b.delegate())?; let data = (size + 2).try_into().expect("Messages are at most 64KiB"); diff --git a/src/new_base/record.rs b/src/new_base/record.rs index dd0771564..19af367a1 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -74,7 +74,7 @@ where let (&ttl, rest) = <&TTL>::split_message_bytes(contents, rest)?; let rdata_start = rest; let (_, rest) = - <&SizePrefixed<[u8]>>::split_message_bytes(contents, rest)?; + <&SizePrefixed>::split_message_bytes(contents, rest)?; let rdata = D::parse_record_data(&contents[..rest], rdata_start + 2, rtype)?; @@ -113,7 +113,7 @@ where builder.append_bytes(self.rtype.as_bytes())?; builder.append_bytes(self.rclass.as_bytes())?; builder.append_bytes(self.ttl.as_bytes())?; - SizePrefixed::new(&self.rdata) + SizePrefixed::::new(&self.rdata) .build_into_message(builder.delegate())?; Ok(builder.commit()) } @@ -131,7 +131,7 @@ where let (rtype, rest) = RType::split_bytes(rest)?; let (rclass, rest) = RClass::split_bytes(rest)?; let (ttl, rest) = TTL::split_bytes(rest)?; - let (rdata, rest) = <&SizePrefixed<[u8]>>::split_bytes(rest)?; + let (rdata, rest) = <&SizePrefixed>::split_bytes(rest)?; let rdata = D::parse_record_data_bytes(rdata, rtype)?; Ok((Self::new(rname, rtype, rclass, ttl, rdata), rest)) @@ -166,7 +166,8 @@ where bytes = self.rtype.as_bytes().build_bytes(bytes)?; bytes = self.rclass.as_bytes().build_bytes(bytes)?; bytes = self.ttl.as_bytes().build_bytes(bytes)?; - bytes = SizePrefixed::new(&self.rdata).build_bytes(bytes)?; + bytes = + SizePrefixed::::new(&self.rdata).build_bytes(bytes)?; Ok(bytes) } diff --git a/src/new_base/wire/ints.rs b/src/new_base/wire/ints.rs index a8170ea83..3753563da 100644 --- a/src/new_base/wire/ints.rs +++ b/src/new_base/wire/ints.rs @@ -292,3 +292,17 @@ define_int! { /// An unsigned 64-bit integer in network endianness. U64([u8; 8]) = u64; } + +impl TryFrom for U16 { + type Error = >::Error; + + fn try_from(value: usize) -> Result { + u16::try_from(value).map(U16::new) + } +} + +impl From for usize { + fn from(value: U16) -> Self { + usize::from(value.get()) + } +} diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index 9053c6431..bc403457a 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -2,6 +2,7 @@ use core::{ borrow::{Borrow, BorrowMut}, + fmt, ops::{Deref, DerefMut}, }; @@ -12,26 +13,26 @@ use crate::new_base::{ use super::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, SplitBytes, - SplitBytesByRef, TruncationError, U16, + SplitBytesByRef, TruncationError, }; //----------- SizePrefixed --------------------------------------------------- -/// A wrapper adding a 16-bit size prefix to a message. +/// A wrapper adding a size prefix to a message. /// /// This is a common element in DNS messages (e.g. for record data and EDNS -/// options). When serialized as bytes, the inner value is prefixed with a -/// 16-bit network-endian integer indicating the length of the inner value in -/// bytes. +/// options). When serialized as bytes, the inner value is prefixed with an +/// integer (often a [`U16`](super::U16)) indicating the length of the inner +/// value in bytes. #[derive(Copy, Clone)] #[repr(C)] -pub struct SizePrefixed { +pub struct SizePrefixed { /// The size prefix (needed for 'ParseBytesByRef' / 'AsBytes'). /// /// This value is always consistent with the size of 'data' if it is /// (de)serialized in-place. By the bounds on 'ParseBytesByRef' and /// 'AsBytes', the serialized size is the same as 'size_of_val(&data)'. - size: U16, + size: S, /// The inner data. data: T, @@ -39,21 +40,21 @@ pub struct SizePrefixed { //--- Construction -impl SizePrefixed { - const VALID_SIZE: () = assert!(core::mem::size_of::() < 65536); - +impl SizePrefixed +where + S: TryFrom, +{ /// Construct a [`SizePrefixed`]. /// /// # Panics /// - /// Panics if the data is 64KiB or more in size. - pub const fn new(data: T) -> Self { - // Force the 'VALID_SIZE' assertion to be evaluated. - #[allow(clippy::let_unit_value)] - let _ = Self::VALID_SIZE; - + /// Panics if the data size does not fit in `S`. + pub fn new(data: T) -> Self { + let size = core::mem::size_of::(); Self { - size: U16::new(core::mem::size_of::() as u16), + size: S::try_from(size).unwrap_or_else(|_| { + panic!("`data.len()` does not fit in the size field") + }), data, } } @@ -61,7 +62,10 @@ impl SizePrefixed { //--- Conversion from the inner data -impl From for SizePrefixed { +impl From for SizePrefixed +where + S: TryFrom, +{ fn from(value: T) -> Self { Self::new(value) } @@ -69,7 +73,7 @@ impl From for SizePrefixed { //--- Access to the inner data -impl Deref for SizePrefixed { +impl Deref for SizePrefixed { type Target = T; fn deref(&self) -> &Self::Target { @@ -77,31 +81,31 @@ impl Deref for SizePrefixed { } } -impl DerefMut for SizePrefixed { +impl DerefMut for SizePrefixed { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.data } } -impl Borrow for SizePrefixed { +impl Borrow for SizePrefixed { fn borrow(&self) -> &T { &self.data } } -impl BorrowMut for SizePrefixed { +impl BorrowMut for SizePrefixed { fn borrow_mut(&mut self) -> &mut T { &mut self.data } } -impl AsRef for SizePrefixed { +impl AsRef for SizePrefixed { fn as_ref(&self) -> &T { &self.data } } -impl AsMut for SizePrefixed { +impl AsMut for SizePrefixed { fn as_mut(&mut self) -> &mut T { &mut self.data } @@ -109,26 +113,36 @@ impl AsMut for SizePrefixed { //--- Parsing from DNS messages -impl<'b, T: ParseMessageBytes<'b>> ParseMessageBytes<'b> for SizePrefixed { +impl<'b, S, T: ParseMessageBytes<'b>> ParseMessageBytes<'b> + for SizePrefixed +where + S: SplitMessageBytes<'b> + TryFrom + TryInto, +{ fn parse_message_bytes( contents: &'b [u8], start: usize, ) -> Result { - let (&size, rest) = <&U16>::split_message_bytes(contents, start)?; - if rest + size.get() as usize != contents.len() { + let (size, rest) = S::split_message_bytes(contents, start)?; + let size = size.try_into().map_err(|_| ParseError)?; + if rest + size != contents.len() { return Err(ParseError); } T::parse_message_bytes(contents, rest).map(Self::new) } } -impl<'b, T: ParseMessageBytes<'b>> SplitMessageBytes<'b> for SizePrefixed { +impl<'b, S, T: ParseMessageBytes<'b>> SplitMessageBytes<'b> + for SizePrefixed +where + S: SplitMessageBytes<'b> + TryFrom + TryInto, +{ fn split_message_bytes( contents: &'b [u8], start: usize, ) -> Result<(Self, usize), ParseError> { - let (&size, rest) = <&U16>::split_message_bytes(contents, start)?; - let (start, rest) = (rest, rest + size.get() as usize); + let (size, rest) = S::split_message_bytes(contents, start)?; + let size = size.try_into().map_err(|_| ParseError)?; + let (start, rest) = (rest, rest + size); let contents = contents.get(..rest).ok_or(ParseError)?; let data = T::parse_message_bytes(contents, start)?; Ok((Self::new(data), rest)) @@ -137,33 +151,46 @@ impl<'b, T: ParseMessageBytes<'b>> SplitMessageBytes<'b> for SizePrefixed { //--- Parsing from bytes -impl<'b, T: ParseBytes<'b>> ParseBytes<'b> for SizePrefixed { +impl<'b, S, T: ParseBytes<'b>> ParseBytes<'b> for SizePrefixed +where + S: SplitBytes<'b> + TryFrom + TryInto, +{ fn parse_bytes(bytes: &'b [u8]) -> Result { - let (size, rest) = U16::split_bytes(bytes)?; - if rest.len() != size.get() as usize { + let (size, rest) = S::split_bytes(bytes)?; + let size = size.try_into().map_err(|_| ParseError)?; + if rest.len() != size { return Err(ParseError); } T::parse_bytes(bytes).map(Self::new) } } -impl<'b, T: ParseBytes<'b>> SplitBytes<'b> for SizePrefixed { +impl<'b, S, T: ParseBytes<'b>> SplitBytes<'b> for SizePrefixed +where + S: SplitBytes<'b> + TryFrom + TryInto, +{ fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { - let (size, rest) = U16::split_bytes(bytes)?; - if rest.len() < size.get() as usize { + let (size, rest) = S::split_bytes(bytes)?; + let size = size.try_into().map_err(|_| ParseError)?; + if rest.len() < size { return Err(ParseError); } - let (data, rest) = rest.split_at(size.get() as usize); + let (data, rest) = rest.split_at(size); let data = T::parse_bytes(data)?; Ok((Self::new(data), rest)) } } -unsafe impl ParseBytesByRef for SizePrefixed { +unsafe impl ParseBytesByRef + for SizePrefixed +where + S: SplitBytesByRef + Copy + TryFrom + TryInto, +{ fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { let addr = bytes.as_ptr(); - let (size, rest) = U16::split_bytes_by_ref(bytes)?; - if rest.len() != size.get() as usize { + let (&size, rest) = S::split_bytes_by_ref(bytes)?; + let size = size.try_into().map_err(|_| ParseError)?; + if rest.len() != size { return Err(ParseError); } let last = T::parse_bytes_by_ref(rest)?; @@ -179,8 +206,9 @@ unsafe impl ParseBytesByRef for SizePrefixed { fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { let addr = bytes.as_ptr(); - let (size, rest) = U16::split_bytes_by_mut(bytes)?; - if rest.len() != size.get() as usize { + let (&mut size, rest) = S::split_bytes_by_mut(bytes)?; + let size = size.try_into().map_err(|_| ParseError)?; + if rest.len() != size { return Err(ParseError); } let last = T::parse_bytes_by_mut(rest)?; @@ -199,16 +227,21 @@ unsafe impl ParseBytesByRef for SizePrefixed { } } -unsafe impl SplitBytesByRef for SizePrefixed { +unsafe impl SplitBytesByRef + for SizePrefixed +where + S: SplitBytesByRef + Copy + TryFrom + TryInto, +{ fn split_bytes_by_ref( bytes: &[u8], ) -> Result<(&Self, &[u8]), ParseError> { let addr = bytes.as_ptr(); - let (size, rest) = U16::split_bytes_by_ref(bytes)?; - if rest.len() < size.get() as usize { + let (&size, rest) = S::split_bytes_by_ref(bytes)?; + let size = size.try_into().map_err(|_| ParseError)?; + if rest.len() < size { return Err(ParseError); } - let (data, rest) = rest.split_at(size.get() as usize); + let (data, rest) = rest.split_at(size); let last = T::parse_bytes_by_ref(data)?; let ptr = last.ptr_with_address(addr as *const ()); @@ -224,11 +257,12 @@ unsafe impl SplitBytesByRef for SizePrefixed { bytes: &mut [u8], ) -> Result<(&mut Self, &mut [u8]), ParseError> { let addr = bytes.as_ptr(); - let (size, rest) = U16::split_bytes_by_mut(bytes)?; - if rest.len() < size.get() as usize { + let (&mut size, rest) = S::split_bytes_by_mut(bytes)?; + let size = size.try_into().map_err(|_| ParseError)?; + if rest.len() < size { return Err(ParseError); } - let (data, rest) = rest.split_at_mut(size.get() as usize); + let (data, rest) = rest.split_at_mut(size); let last = T::parse_bytes_by_mut(data)?; let ptr = last.ptr_with_address(addr as *const ()); @@ -243,57 +277,62 @@ unsafe impl SplitBytesByRef for SizePrefixed { //--- Building into DNS messages -impl BuildIntoMessage for SizePrefixed { +impl BuildIntoMessage for SizePrefixed +where + S: AsBytes + Default + TryFrom, +{ fn build_into_message( &self, mut builder: build::Builder<'_>, ) -> BuildResult { assert_eq!(builder.uncommitted(), &[] as &[u8]); - builder.append_bytes(&0u16.to_be_bytes())?; + let size_size = core::mem::size_of::(); + builder.append_bytes(S::default().as_bytes())?; self.data.build_into_message(builder.delegate())?; - let size = builder.uncommitted().len() - 2; - let size = u16::try_from(size).expect("the data never exceeds 64KiB"); - // SAFETY: A 'U16' is being modified, not a domain name. - let size_buf = unsafe { &mut builder.uncommitted_mut()[0..2] }; - size_buf.copy_from_slice(&size.to_be_bytes()); + let size = builder.uncommitted().len() - size_size; + let size = S::try_from(size).unwrap_or_else(|_| { + panic!("`data.len()` does not fit in the size field") + }); + // SAFETY: An 'S' is being modified, not a domain name. + let size_buf = unsafe { &mut builder.uncommitted_mut()[..size_size] }; + size_buf.copy_from_slice(&size.as_bytes()); Ok(builder.commit()) } } //--- Building into byte strings -impl BuildBytes for SizePrefixed { +impl BuildBytes for SizePrefixed +where + S: AsBytes + Default + TryFrom, +{ fn build_bytes<'b>( &self, bytes: &'b mut [u8], ) -> Result<&'b mut [u8], TruncationError> { // Get the size area to fill in afterwards. - let (size_buf, data_buf) = - U16::split_bytes_by_mut(bytes).map_err(|_| TruncationError)?; + let size_size = core::mem::size_of::(); + if bytes.len() < size_size { + return Err(TruncationError); + } + let (size_buf, data_buf) = bytes.split_at_mut(size_size); let data_buf_len = data_buf.len(); let rest = self.data.build_bytes(data_buf)?; let size = data_buf_len - rest.len(); - assert!(size < 65536, "Cannot serialize >=64KiB into 16-bit integer"); - *size_buf = U16::new(size as u16); + let size = S::try_from(size).unwrap_or_else(|_| { + panic!("`data.len()` does not fit in the size field") + }); + size_buf.copy_from_slice(size.as_bytes()); Ok(rest) } } -unsafe impl AsBytes for SizePrefixed { - // For debugging, we check that the serialized size is correct. - #[cfg(debug_assertions)] - fn as_bytes(&self) -> &[u8] { - let size: usize = self.size.get().into(); - assert_eq!(size, core::mem::size_of_val(&self.data)); +unsafe impl AsBytes for SizePrefixed {} - // SAFETY: - // - 'Self' has no padding bytes and no interior mutability. - // - Its size in memory is exactly 'size_of_val(self)'. - unsafe { - core::slice::from_raw_parts( - self as *const Self as *const u8, - core::mem::size_of_val(self), - ) - } +//--- Formatting + +impl fmt::Debug for SizePrefixed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("SizePrefixed").field(&&self.data).finish() } } diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index f2cf6b710..fa16a241d 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -43,7 +43,7 @@ pub struct EdnsRecord<'a> { pub flags: EdnsFlags, /// Extended DNS options. - pub options: SizePrefixed<&'a Opt>, + pub options: SizePrefixed, } //--- Parsing from DNS messages @@ -78,7 +78,7 @@ impl<'a> SplitBytes<'a> for EdnsRecord<'a> { let (&ext_rcode, rest) = <&u8>::split_bytes(rest)?; let (&version, rest) = <&u8>::split_bytes(rest)?; let (&flags, rest) = <&EdnsFlags>::split_bytes(rest)?; - let (options, rest) = >::split_bytes(rest)?; + let (options, rest) = >::split_bytes(rest)?; Ok(( Self { @@ -235,7 +235,7 @@ impl<'b> ParseBytes<'b> for EdnsOption<'b> { impl<'b> SplitBytes<'b> for EdnsOption<'b> { fn split_bytes(bytes: &'b [u8]) -> Result<(Self, &'b [u8]), ParseError> { let (code, rest) = OptionCode::split_bytes(bytes)?; - let (data, rest) = <&SizePrefixed<[u8]>>::split_bytes(rest)?; + let (data, rest) = <&SizePrefixed>::split_bytes(rest)?; let this = match code { OptionCode::COOKIE => match data.len() { @@ -272,7 +272,7 @@ impl BuildBytes for EdnsOption<'_> { Self::ExtError(this) => this.as_bytes(), Self::Unknown(_, this) => this.as_bytes(), }; - bytes = SizePrefixed::new(data).build_bytes(bytes)?; + bytes = SizePrefixed::::new(data).build_bytes(bytes)?; Ok(bytes) } From 79c4fecaf657f275b73528d8c5708c0e88dbcd77 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 7 Mar 2025 00:01:25 +0100 Subject: [PATCH 143/167] [new_rdata] Implement stubs of DNSSEC record types --- src/new_rdata/dnssec/dnskey.rs | 44 +++++++++++ src/new_rdata/dnssec/ds.rs | 68 +++++++++++++++++ src/new_rdata/dnssec/mod.rs | 68 +++++++++++++++++ src/new_rdata/dnssec/nsec.rs | 127 +++++++++++++++++++++++++++++++ src/new_rdata/dnssec/nsec3.rs | 135 +++++++++++++++++++++++++++++++++ src/new_rdata/dnssec/rrsig.rs | 38 ++++++++++ src/new_rdata/mod.rs | 3 + 7 files changed, 483 insertions(+) create mode 100644 src/new_rdata/dnssec/dnskey.rs create mode 100644 src/new_rdata/dnssec/ds.rs create mode 100644 src/new_rdata/dnssec/mod.rs create mode 100644 src/new_rdata/dnssec/nsec.rs create mode 100644 src/new_rdata/dnssec/nsec3.rs create mode 100644 src/new_rdata/dnssec/rrsig.rs diff --git a/src/new_rdata/dnssec/dnskey.rs b/src/new_rdata/dnssec/dnskey.rs new file mode 100644 index 000000000..8a7b7c80a --- /dev/null +++ b/src/new_rdata/dnssec/dnskey.rs @@ -0,0 +1,44 @@ +use domain_macros::*; + +use crate::new_base::wire::U16; + +use super::SecAlg; + +//----------- DNSKey --------------------------------------------------------- + +/// A cryptographic key for signing DNS records. +#[derive(AsBytes, BuildBytes, ParseBytesByRef)] +#[repr(C)] +pub struct DNSKey { + /// Flags describing the usage of the key. + pub flags: DNSKeyFlags, + + /// The protocol version of the key. + pub protocol: u8, + + /// The cryptographic algorithm used by this key. + pub algorithm: SecAlg, + + /// The serialized public key. + pub key: [u8], +} + +//----------- DNSKeyFlags ---------------------------------------------------- + +/// Flags describing a [`DNSKey`]. +#[derive( + Copy, + Clone, + Default, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(transparent)] +pub struct DNSKeyFlags { + inner: U16, +} diff --git a/src/new_rdata/dnssec/ds.rs b/src/new_rdata/dnssec/ds.rs new file mode 100644 index 000000000..270e5a080 --- /dev/null +++ b/src/new_rdata/dnssec/ds.rs @@ -0,0 +1,68 @@ +use core::fmt; + +use domain_macros::*; + +use crate::new_base::wire::U16; + +use super::SecAlg; + +//----------- Ds ------------------------------------------------------------- + +/// The signing key for a delegated zone. +#[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] +#[repr(C)] +pub struct Ds { + /// The key tag of the signing key. + pub keytag: U16, + + /// The cryptographic algorithm used by the signing key. + pub algorithm: SecAlg, + + /// The algorithm used to calculate the key digest. + pub digest_type: DigestType, + + /// A serialized digest of the signing key. + pub digest: [u8], +} + +//----------- DigestType ----------------------------------------------------- + +/// A cryptographic digest algorithm. +#[derive( + Copy, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(transparent)] +pub struct DigestType { + /// The algorithm code. + pub code: u8, +} + +//--- Associated Constants + +impl DigestType { + /// The SHA-1 algorithm. + pub const SHA1: Self = Self { code: 1 }; +} + +//--- Formatting + +impl fmt::Debug for DigestType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::SHA1 => "DigestType::SHA1", + _ => return write!(f, "DigestType({})", self.code), + }) + } +} diff --git a/src/new_rdata/dnssec/mod.rs b/src/new_rdata/dnssec/mod.rs new file mode 100644 index 000000000..5422cea4f --- /dev/null +++ b/src/new_rdata/dnssec/mod.rs @@ -0,0 +1,68 @@ +//! Record types relating to DNSSEC. + +use core::fmt; + +use domain_macros::*; + +//----------- Submodules ----------------------------------------------------- + +mod dnskey; +pub use dnskey::DNSKey; + +mod rrsig; +pub use rrsig::RRSig; + +mod nsec; +pub use nsec::{NSec, TypeBitmaps}; + +mod nsec3; +pub use nsec3::NSec3; + +mod ds; +pub use ds::Ds; + +//----------- SecAlg --------------------------------------------------------- + +/// A cryptographic algorithm for DNS security. +#[derive( + Copy, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(transparent)] +pub struct SecAlg { + /// The algorithm code. + pub code: u8, +} + +//--- Associated Constants + +impl SecAlg { + /// The DSA/SHA-1 algorithm. + pub const DSA_SHA1: Self = Self { code: 3 }; + + /// The RSA/SHA-1 algorithm. + pub const RSA_SHA1: Self = Self { code: 5 }; +} + +//--- Formatting + +impl fmt::Debug for SecAlg { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::DSA_SHA1 => "SecAlg::DSA_SHA1", + Self::RSA_SHA1 => "SecAlg::RSA_SHA1", + _ => return write!(f, "SecAlg({})", self.code), + }) + } +} diff --git a/src/new_rdata/dnssec/nsec.rs b/src/new_rdata/dnssec/nsec.rs new file mode 100644 index 000000000..682426ea7 --- /dev/null +++ b/src/new_rdata/dnssec/nsec.rs @@ -0,0 +1,127 @@ +use core::{fmt, mem}; + +use domain_macros::*; + +use crate::new_base::{ + name::Name, + wire::{ParseBytesByRef, ParseError}, + RType, +}; + +//----------- NSec ----------------------------------------------------------- + +/// An indication of the non-existence of a set of DNS records (version 1). +#[derive(Clone, Debug, BuildBytes, ParseBytes)] +pub struct NSec<'a> { + /// The name of the next existing DNS record. + pub next: &'a Name, + + /// The types of the records that exist at this owner name. + pub types: &'a TypeBitmaps, +} + +//----------- TypeBitmaps ---------------------------------------------------- + +/// A bitmap of DNS record types. +#[derive(AsBytes)] +#[repr(transparent)] +pub struct TypeBitmaps { + octets: [u8], +} + +//--- Inspection + +impl TypeBitmaps { + /// The types in this bitmap. + pub fn types(&self) -> impl Iterator + '_ { + fn split_window(octets: &[u8]) -> Option<(u8, &[u8], &[u8])> { + let &[num, len, ref rest @ ..] = octets else { + return None; + }; + + let (bits, rest) = rest.split_at(len as usize); + Some((num, bits, rest)) + } + + core::iter::successors(split_window(&self.octets), |(_, _, rest)| { + split_window(rest) + }) + .flat_map(move |(num, bits, _)| { + bits.iter().enumerate().flat_map(move |(i, &b)| { + (0..8).filter(move |&j| ((b >> j) & 1) != 0).map(move |j| { + RType::from(u16::from_be_bytes([num, (i * 8 + j) as u8])) + }) + }) + }) + } +} + +//--- Formatting + +impl fmt::Debug for TypeBitmaps { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_set().entries(self.types()).finish() + } +} + +//--- Parsing + +impl TypeBitmaps { + fn validate_bytes(mut octets: &[u8]) -> Result<(), ParseError> { + // At least one bitmap is mandatory. + let mut num = octets.first().ok_or(ParseError)?; + octets = Self::validate_window_bytes(octets)?; + + while let Some(next) = octets.first() { + if mem::replace(&mut num, next) >= next { + return Err(ParseError); + } + + octets = Self::validate_window_bytes(octets)?; + } + + Ok(()) + } + + fn validate_window_bytes(octets: &[u8]) -> Result<&[u8], ParseError> { + let &[_num, len, ref rest @ ..] = octets else { + return Err(ParseError); + }; + + if !(1..=32).contains(&len) || rest.len() < len as usize { + return Err(ParseError); + } + + let (bits, rest) = rest.split_at(len as usize); + if bits.last() == Some(&0) { + // Trailing zeros are not allowed. + return Err(ParseError); + } + + Ok(rest) + } +} + +// SAFETY: The implementations of 'parse_bytes_by_{ref,mut}()' always parse +// the entirety of the input on success, satisfying the safety requirements. +unsafe impl ParseBytesByRef for TypeBitmaps { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + Self::validate_bytes(bytes)?; + + // SAFETY: 'TypeBitmaps' is 'repr(transparent)' to '[u8]', and so + // references to '[u8]' can be transmuted to 'TypeBitmaps' soundly. + unsafe { core::mem::transmute(bytes) } + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + Self::validate_bytes(bytes)?; + + // SAFETY: 'TypeBitmaps' is 'repr(transparent)' to '[u8]', and so + // references to '[u8]' can be transmuted to 'TypeBitmaps' soundly. + unsafe { core::mem::transmute(bytes) } + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + self.octets.ptr_with_address(addr) as *const Self + } +} diff --git a/src/new_rdata/dnssec/nsec3.rs b/src/new_rdata/dnssec/nsec3.rs new file mode 100644 index 000000000..732766224 --- /dev/null +++ b/src/new_rdata/dnssec/nsec3.rs @@ -0,0 +1,135 @@ +use core::fmt; + +use domain_macros::*; + +use crate::new_base::wire::{SizePrefixed, U16}; + +use super::TypeBitmaps; + +//----------- NSec3 ---------------------------------------------------------- + +/// An indication of the non-existence of a DNS record (version 3). +#[derive(Clone, Debug, BuildBytes, ParseBytes)] +pub struct NSec3<'a> { + /// The algorithm used to hash names. + pub algorithm: NSec3HashAlg, + + /// Flags modifying the behaviour of the record. + pub flags: NSec3Flags, + + /// The number of iterations of the underlying hash function per name. + pub iterations: U16, + + /// The salt used to randomize the hash function. + pub salt: &'a SizePrefixed, + + /// The name of the next existing DNS record. + pub next: &'a SizePrefixed, + + /// The types of the records that exist at this owner name. + pub types: &'a TypeBitmaps, +} + +//----------- NSec3HashAlg --------------------------------------------------- + +/// The hash algorithm used with [`NSec3`] records. +#[derive( + Copy, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(transparent)] +pub struct NSec3HashAlg { + /// The algorithm code. + pub code: u8, +} + +//--- Associated Constants + +impl NSec3HashAlg { + /// The SHA-1 algorithm. + pub const SHA1: Self = Self { code: 1 }; +} + +//--- Formatting + +impl fmt::Debug for NSec3HashAlg { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match *self { + Self::SHA1 => "NSec3HashAlg::SHA1", + _ => return write!(f, "NSec3HashAlg({})", self.code), + }) + } +} + +//----------- NSec3Flags ----------------------------------------------------- + +/// Flags modifying the behaviour of an [`NSec3`] record. +#[derive( + Copy, + Clone, + Default, + Hash, + AsBytes, + BuildBytes, + ParseBytes, + ParseBytesByRef, + SplitBytes, + SplitBytesByRef, +)] +#[repr(transparent)] +pub struct NSec3Flags { + inner: u8, +} + +//--- Interaction + +impl NSec3Flags { + /// Get the specified flag bit. + fn get_flag(&self, pos: u32) -> bool { + self.inner & (1 << pos) != 0 + } + + /// Set the specified flag bit. + fn set_flag(mut self, pos: u32, value: bool) -> Self { + self.inner &= !(1 << pos); + self.inner |= (value as u8) << pos; + self + } + + /// The raw flags bits. + pub fn bits(&self) -> u8 { + self.inner + } + + /// Whether unsigned delegations can exist in the covered range. + pub fn is_optout(&self) -> bool { + self.get_flag(0) + } + + /// Allow unsigned delegations to exist in the covered raneg. + pub fn set_optout(self, value: bool) -> Self { + self.set_flag(0, value) + } +} + +//--- Formatting + +impl fmt::Debug for NSec3Flags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NSec3Flags") + .field("optout", &self.is_optout()) + .field("bits", &self.bits()) + .finish() + } +} diff --git a/src/new_rdata/dnssec/rrsig.rs b/src/new_rdata/dnssec/rrsig.rs new file mode 100644 index 000000000..ea9a6f5c6 --- /dev/null +++ b/src/new_rdata/dnssec/rrsig.rs @@ -0,0 +1,38 @@ +use domain_macros::*; + +use crate::new_base::{name::Name, wire::U16, RType, Serial, TTL}; + +use super::SecAlg; + +//----------- RRSig ---------------------------------------------------------- + +/// A cryptographic signature on a DNS record set. +#[derive(Clone, Debug, PartialEq, Eq, BuildBytes, ParseBytes)] +pub struct RRSig<'a> { + /// The type of the RRset being signed. + pub rtype: RType, + + /// The cryptographic algorithm used to construct the signature. + pub algorithm: SecAlg, + + /// The number of labels in the signed RRset's owner name. + pub labels: u8, + + /// The (original) TTL of the signed RRset. + pub ttl: TTL, + + /// The point in time when the signature expires. + pub expiration: Serial, + + /// The point in time when the signature was created. + pub inception: Serial, + + /// The key tag of the key used to make the signature. + pub keytag: U16, + + /// The name identifying the signer. + pub signer: &'a Name, + + /// The serialized cryptographic signature. + pub signature: &'a [u8], +} diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 19524a0bf..c5cdda0a1 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -20,6 +20,9 @@ pub use ipv6::Aaaa; mod edns; pub use edns::{EdnsOptionsIter, Opt}; +mod dnssec; +pub use dnssec::{DNSKey, Ds, NSec, NSec3, RRSig}; + //----------- RecordData ----------------------------------------------------- /// DNS record data. From 31fea55f9336e7eeda3ee156575cf3d9cdf8107e Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 7 Mar 2025 10:35:09 +0100 Subject: [PATCH 144/167] [src/new_base/size_prefixed] Fix clippy lint --- src/new_base/wire/size_prefixed.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index bc403457a..baec485cc 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -295,7 +295,7 @@ where }); // SAFETY: An 'S' is being modified, not a domain name. let size_buf = unsafe { &mut builder.uncommitted_mut()[..size_size] }; - size_buf.copy_from_slice(&size.as_bytes()); + size_buf.copy_from_slice(size.as_bytes()); Ok(builder.commit()) } } From 10be8bfdcb3e01d16b84f88d0782f2dd518973e7 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 7 Mar 2025 10:35:30 +0100 Subject: [PATCH 145/167] [new_rdata/dnssec] Export helper types --- src/new_rdata/dnssec/mod.rs | 4 ++-- src/new_rdata/mod.rs | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/new_rdata/dnssec/mod.rs b/src/new_rdata/dnssec/mod.rs index 5422cea4f..705efc673 100644 --- a/src/new_rdata/dnssec/mod.rs +++ b/src/new_rdata/dnssec/mod.rs @@ -16,10 +16,10 @@ mod nsec; pub use nsec::{NSec, TypeBitmaps}; mod nsec3; -pub use nsec3::NSec3; +pub use nsec3::{NSec3, NSec3Flags, NSec3HashAlg}; mod ds; -pub use ds::Ds; +pub use ds::{DigestType, Ds}; //----------- SecAlg --------------------------------------------------------- diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index c5cdda0a1..8a7c2f491 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -21,7 +21,10 @@ mod edns; pub use edns::{EdnsOptionsIter, Opt}; mod dnssec; -pub use dnssec::{DNSKey, Ds, NSec, NSec3, RRSig}; +pub use dnssec::{ + DNSKey, DigestType, Ds, NSec, NSec3, NSec3Flags, NSec3HashAlg, RRSig, + SecAlg, +}; //----------- RecordData ----------------------------------------------------- From 3d57844ea00587fdf83e2226822ed6fc7c8336b2 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 7 Mar 2025 11:39:33 +0100 Subject: [PATCH 146/167] [new_rdata] Avoid '{BuildInto,ParseFrom}Message' where possible --- src/new_base/build/builder.rs | 13 +++++++++ src/new_base/build/message.rs | 15 ++++++---- src/new_base/record.rs | 16 +++++++++++ src/new_rdata/basic/a.rs | 13 +-------- src/new_rdata/basic/hinfo.rs | 31 +-------------------- src/new_rdata/basic/wks.rs | 13 +-------- src/new_rdata/edns.rs | 15 +--------- src/new_rdata/mod.rs | 52 +++++++++++++++++++---------------- 8 files changed, 72 insertions(+), 96 deletions(-) diff --git a/src/new_base/build/builder.rs b/src/new_base/build/builder.rs index 34c415ce3..175c8bd78 100644 --- a/src/new_base/build/builder.rs +++ b/src/new_base/build/builder.rs @@ -396,6 +396,19 @@ impl Builder<'_> { self.append_with(bytes.len(), |buffer| buffer.copy_from_slice(bytes)) } + /// Serialize an object into bytes and append it. + /// + /// No name compression will be performed. + pub fn append_built_bytes( + &mut self, + object: &impl BuildBytes, + ) -> Result<(), TruncationError> { + let rest = object.build_bytes(self.uninitialized())?.len(); + let appended = self.uninitialized().len() - rest; + self.mark_appended(appended); + Ok(()) + } + /// Compress and append a domain name. pub fn append_name( &mut self, diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 0dc2bdd95..87767528f 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -317,7 +317,7 @@ impl<'b> MessageBuilder<'b, '_> { mod test { use crate::{ new_base::{ - build::{BuildIntoMessage, BuilderContext, MessageState}, + build::{BuilderContext, MessageState}, name::RevName, wire::U16, QClass, QType, Question, RClass, RType, Record, SectionCounts, @@ -425,10 +425,15 @@ mod test { assert!(rb.delegate().append_bytes(&[0u8; 5]).is_err()); - let rdata = A { - octets: [127, 0, 0, 1], - }; - rdata.build_into_message(rb.delegate()).unwrap(); + { + let mut builder = rb.delegate(); + builder + .append_built_bytes(&A { + octets: [127, 0, 0, 1], + }) + .unwrap(); + builder.commit(); + } assert_eq!(rb.rdata(), b"\x7F\x00\x00\x01"); } diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 19af367a1..3af5fd191 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -240,6 +240,22 @@ impl RType { pub const OPT: Self = Self::new(41); } +//--- Conversion to and from 'u16' + +impl From for RType { + fn from(value: u16) -> Self { + Self { + code: U16::new(value), + } + } +} + +impl From for u16 { + fn from(value: RType) -> Self { + value.code.get() + } +} + //--- Formatting impl fmt::Debug for RType { diff --git a/src/new_rdata/basic/a.rs b/src/new_rdata/basic/a.rs index 3f8d8b2aa..14700fdd5 100644 --- a/src/new_rdata/basic/a.rs +++ b/src/new_rdata/basic/a.rs @@ -4,10 +4,7 @@ use core::str::FromStr; use domain_macros::*; -use crate::new_base::{ - build::{self, BuildIntoMessage, BuildResult}, - wire::AsBytes, -}; +use crate::new_base::wire::AsBytes; //----------- A -------------------------------------------------------------- @@ -67,11 +64,3 @@ impl fmt::Display for A { Ipv4Addr::from(*self).fmt(f) } } - -//--- Building into DNS messages - -impl BuildIntoMessage for A { - fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { - self.as_bytes().build_into_message(builder) - } -} diff --git a/src/new_rdata/basic/hinfo.rs b/src/new_rdata/basic/hinfo.rs index 524653f7a..0151c97d3 100644 --- a/src/new_rdata/basic/hinfo.rs +++ b/src/new_rdata/basic/hinfo.rs @@ -1,11 +1,6 @@ use domain_macros::*; -use crate::new_base::{ - build::{self, BuildIntoMessage, BuildResult}, - parse::ParseMessageBytes, - wire::{ParseBytes, ParseError}, - CharStr, -}; +use crate::new_base::CharStr; //----------- HInfo ---------------------------------------------------------- @@ -18,27 +13,3 @@ pub struct HInfo<'a> { /// The OS type. pub os: &'a CharStr, } - -//--- Parsing from DNS messages - -impl<'a> ParseMessageBytes<'a> for HInfo<'a> { - fn parse_message_bytes( - contents: &'a [u8], - start: usize, - ) -> Result { - Self::parse_bytes(&contents[start..]) - } -} - -//--- Building into DNS messages - -impl BuildIntoMessage for HInfo<'_> { - fn build_into_message( - &self, - mut builder: build::Builder<'_>, - ) -> BuildResult { - self.cpu.build_into_message(builder.delegate())?; - self.os.build_into_message(builder.delegate())?; - Ok(builder.commit()) - } -} diff --git a/src/new_rdata/basic/wks.rs b/src/new_rdata/basic/wks.rs index bacadeefa..9e29d4133 100644 --- a/src/new_rdata/basic/wks.rs +++ b/src/new_rdata/basic/wks.rs @@ -2,10 +2,7 @@ use core::fmt; use domain_macros::*; -use crate::new_base::{ - build::{self, BuildIntoMessage, BuildResult}, - wire::AsBytes, -}; +use crate::new_base::wire::AsBytes; use super::A; @@ -53,14 +50,6 @@ impl fmt::Debug for Wks { } } -//--- Building into DNS messages - -impl BuildIntoMessage for Wks { - fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { - self.as_bytes().build_into_message(builder) - } -} - //--- Equality impl PartialEq for Wks { diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 43327b50c..46bba4ed3 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -7,10 +7,7 @@ use core::{fmt, iter::FusedIterator}; use domain_macros::*; use crate::{ - new_base::{ - build::{self, BuildIntoMessage, BuildResult}, - wire::{ParseError, SplitBytes}, - }, + new_base::wire::{ParseError, SplitBytes}, new_edns::EdnsOption, }; @@ -43,16 +40,6 @@ impl fmt::Debug for Opt { } } -// TODO: Formatting. - -//--- Building into DNS messages - -impl BuildIntoMessage for Opt { - fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { - self.contents.build_into_message(builder) - } -} - //----------- EdnsOptionsIter ------------------------------------------------ /// An iterator over EDNS options in an [`Opt`] record. diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 8a7c2f491..807baf0fd 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -5,7 +5,10 @@ use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, parse::{ParseMessageBytes, SplitMessageBytes}, - wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, + wire::{ + AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes, + TruncationError, + }, ParseRecordData, RType, }; @@ -146,9 +149,7 @@ where rtype: RType, ) -> Result { match rtype { - RType::A => { - <&A>::parse_message_bytes(contents, start).map(Self::A) - } + RType::A => <&A>::parse_bytes(&contents[start..]).map(Self::A), RType::NS => { Ns::parse_message_bytes(contents, start).map(Self::Ns) } @@ -159,27 +160,27 @@ where Soa::parse_message_bytes(contents, start).map(Self::Soa) } RType::WKS => { - <&Wks>::parse_message_bytes(contents, start).map(Self::Wks) + <&Wks>::parse_bytes(&contents[start..]).map(Self::Wks) } RType::PTR => { Ptr::parse_message_bytes(contents, start).map(Self::Ptr) } RType::HINFO => { - HInfo::parse_message_bytes(contents, start).map(Self::HInfo) + HInfo::parse_bytes(&contents[start..]).map(Self::HInfo) } RType::MX => { Mx::parse_message_bytes(contents, start).map(Self::Mx) } RType::TXT => { - <&Txt>::parse_message_bytes(contents, start).map(Self::Txt) + <&Txt>::parse_bytes(&contents[start..]).map(Self::Txt) } RType::AAAA => { - <&Aaaa>::parse_message_bytes(contents, start).map(Self::Aaaa) + <&Aaaa>::parse_bytes(&contents[start..]).map(Self::Aaaa) } RType::OPT => { - <&Opt>::parse_message_bytes(contents, start).map(Self::Opt) + <&Opt>::parse_bytes(&contents[start..]).map(Self::Opt) } - _ => <&UnknownRecordData>::parse_message_bytes(contents, start) + _ => <&UnknownRecordData>::parse_bytes(&contents[start..]) .map(|data| Self::Unknown(rtype, data)), } } @@ -209,21 +210,26 @@ where //--- Building record data impl BuildIntoMessage for RecordData<'_, N> { - fn build_into_message(&self, builder: build::Builder<'_>) -> BuildResult { + fn build_into_message( + &self, + mut builder: build::Builder<'_>, + ) -> BuildResult { match self { - Self::A(r) => r.build_into_message(builder), - Self::Ns(r) => r.build_into_message(builder), - Self::CName(r) => r.build_into_message(builder), - Self::Soa(r) => r.build_into_message(builder), - Self::Wks(r) => r.build_into_message(builder), - Self::Ptr(r) => r.build_into_message(builder), - Self::HInfo(r) => r.build_into_message(builder), - Self::Mx(r) => r.build_into_message(builder), - Self::Txt(r) => r.build_into_message(builder), - Self::Aaaa(r) => r.build_into_message(builder), - Self::Opt(r) => r.build_into_message(builder), - Self::Unknown(_, r) => r.octets.build_into_message(builder), + Self::A(r) => builder.append_bytes(r.as_bytes())?, + Self::Ns(r) => return r.build_into_message(builder), + Self::CName(r) => return r.build_into_message(builder), + Self::Soa(r) => return r.build_into_message(builder), + Self::Wks(r) => builder.append_bytes(r.as_bytes())?, + Self::Ptr(r) => return r.build_into_message(builder), + Self::HInfo(r) => builder.append_built_bytes(r)?, + Self::Mx(r) => return r.build_into_message(builder), + Self::Txt(r) => builder.append_bytes(r.as_bytes())?, + Self::Aaaa(r) => builder.append_bytes(r.as_bytes())?, + Self::Opt(r) => builder.append_bytes(r.as_bytes())?, + Self::Unknown(_, r) => builder.append_bytes(r.as_bytes())?, } + + Ok(builder.commit()) } } From 646018dc2a59dd5e758ec1ae2507aef2dc6171dd Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 7 Mar 2025 11:40:28 +0100 Subject: [PATCH 147/167] [new_rdata] Incorporate DNSSEC record types --- src/new_base/record.rs | 24 ++++++++++ src/new_base/wire/size_prefixed.rs | 41 ++++++++++++---- src/new_rdata/dnssec/dnskey.rs | 61 ++++++++++++++++++++++- src/new_rdata/dnssec/ds.rs | 2 +- src/new_rdata/dnssec/mod.rs | 4 +- src/new_rdata/dnssec/nsec.rs | 4 +- src/new_rdata/dnssec/nsec3.rs | 33 ++++++++++++- src/new_rdata/mod.rs | 77 +++++++++++++++++++++++++++++- 8 files changed, 227 insertions(+), 19 deletions(-) diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 3af5fd191..9a2813b9a 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -238,6 +238,24 @@ impl RType { /// The type of an [`Opt`](crate::new_rdata::Opt) record. pub const OPT: Self = Self::new(41); + + /// The type of a [`Ds`](crate::new_rdata::Ds) record. + pub const DS: Self = Self::new(43); + + /// The type of an [`RRSig`](crate::new_rdata::RRSig) record. + pub const RRSIG: Self = Self::new(46); + + /// The type of an [`NSec`](crate::new_rdata::NSec) record. + pub const NSEC: Self = Self::new(47); + + /// The type of a [`DNSKey`](crate::new_rdata::DNSKey) record. + pub const DNSKEY: Self = Self::new(48); + + /// The type of an [`NSec3`](crate::new_rdata::NSec3) record. + pub const NSEC3: Self = Self::new(50); + + /// The type of an [`NSec3Param`](crate::new_rdata::NSec3Param) record. + pub const NSEC3PARAM: Self = Self::new(51); } //--- Conversion to and from 'u16' @@ -272,6 +290,12 @@ impl fmt::Debug for RType { Self::TXT => "RType::TXT", Self::AAAA => "RType::AAAA", Self::OPT => "RType::OPT", + Self::DS => "RType::DS", + Self::RRSIG => "RType::RRSIG", + Self::NSEC => "RType::NSEC", + Self::DNSKEY => "RType::DNSKEY", + Self::NSEC3 => "RType::NSEC3", + Self::NSEC3PARAM => "RType::NSEC3PARAM", _ => return write!(f, "RType({})", self.code), }) } diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index baec485cc..42c11e295 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -2,6 +2,7 @@ use core::{ borrow::{Borrow, BorrowMut}, + cmp::Ordering, fmt, ops::{Deref, DerefMut}, }; @@ -111,6 +112,38 @@ impl AsMut for SizePrefixed { } } +//--- Equality + +impl PartialEq for SizePrefixed { + fn eq(&self, other: &Self) -> bool { + self.data == other.data + } +} + +impl Eq for SizePrefixed {} + +//--- Ordering + +impl PartialOrd for SizePrefixed { + fn partial_cmp(&self, other: &Self) -> Option { + self.data.partial_cmp(&other.data) + } +} + +impl Ord for SizePrefixed { + fn cmp(&self, other: &Self) -> Ordering { + self.data.cmp(&other.data) + } +} + +//--- Formatting + +impl fmt::Debug for SizePrefixed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("SizePrefixed").field(&&self.data).finish() + } +} + //--- Parsing from DNS messages impl<'b, S, T: ParseMessageBytes<'b>> ParseMessageBytes<'b> @@ -328,11 +361,3 @@ where } unsafe impl AsBytes for SizePrefixed {} - -//--- Formatting - -impl fmt::Debug for SizePrefixed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("SizePrefixed").field(&&self.data).finish() - } -} diff --git a/src/new_rdata/dnssec/dnskey.rs b/src/new_rdata/dnssec/dnskey.rs index 8a7b7c80a..ac77c4c57 100644 --- a/src/new_rdata/dnssec/dnskey.rs +++ b/src/new_rdata/dnssec/dnskey.rs @@ -1,3 +1,5 @@ +use core::fmt; + use domain_macros::*; use crate::new_base::wire::U16; @@ -6,8 +8,8 @@ use super::SecAlg; //----------- DNSKey --------------------------------------------------------- -/// A cryptographic key for signing DNS records. -#[derive(AsBytes, BuildBytes, ParseBytesByRef)] +/// A cryptographic key for DNS security. +#[derive(Debug, PartialEq, Eq, AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C)] pub struct DNSKey { /// Flags describing the usage of the key. @@ -31,6 +33,8 @@ pub struct DNSKey { Clone, Default, Hash, + PartialEq, + Eq, AsBytes, BuildBytes, ParseBytes, @@ -42,3 +46,56 @@ pub struct DNSKey { pub struct DNSKeyFlags { inner: U16, } + +//--- Interaction + +impl DNSKeyFlags { + /// Get the specified flag bit. + fn get_flag(&self, pos: u32) -> bool { + self.inner.get() & (1 << pos) != 0 + } + + /// Set the specified flag bit. + fn set_flag(mut self, pos: u32, value: bool) -> Self { + self.inner &= !(1 << pos); + self.inner |= (value as u16) << pos; + self + } + + /// The raw flags bits. + pub fn bits(&self) -> u16 { + self.inner.get() + } + + /// Whether this key is used for signing DNS records. + pub fn is_zone_key(&self) -> bool { + self.get_flag(8) + } + + /// Make this key usable for signing DNS records. + pub fn set_zone_key(self, value: bool) -> Self { + self.set_flag(8, value) + } + + /// Whether external entities are expected to point to this key. + pub fn is_secure_entry_point(&self) -> bool { + self.get_flag(0) + } + + /// Expect external entities to point to this key. + pub fn set_secure_entry_point(self, value: bool) -> Self { + self.set_flag(0, value) + } +} + +//--- Formatting + +impl fmt::Debug for DNSKeyFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DNSKeyFlags") + .field("zone_key", &self.is_zone_key()) + .field("secure_entry_point", &self.is_secure_entry_point()) + .field("bits", &self.bits()) + .finish() + } +} diff --git a/src/new_rdata/dnssec/ds.rs b/src/new_rdata/dnssec/ds.rs index 270e5a080..369fc355d 100644 --- a/src/new_rdata/dnssec/ds.rs +++ b/src/new_rdata/dnssec/ds.rs @@ -9,7 +9,7 @@ use super::SecAlg; //----------- Ds ------------------------------------------------------------- /// The signing key for a delegated zone. -#[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] +#[derive(Debug, PartialEq, Eq, AsBytes, BuildBytes, ParseBytesByRef)] #[repr(C)] pub struct Ds { /// The key tag of the signing key. diff --git a/src/new_rdata/dnssec/mod.rs b/src/new_rdata/dnssec/mod.rs index 705efc673..6182b615c 100644 --- a/src/new_rdata/dnssec/mod.rs +++ b/src/new_rdata/dnssec/mod.rs @@ -7,7 +7,7 @@ use domain_macros::*; //----------- Submodules ----------------------------------------------------- mod dnskey; -pub use dnskey::DNSKey; +pub use dnskey::{DNSKey, DNSKeyFlags}; mod rrsig; pub use rrsig::RRSig; @@ -16,7 +16,7 @@ mod nsec; pub use nsec::{NSec, TypeBitmaps}; mod nsec3; -pub use nsec3::{NSec3, NSec3Flags, NSec3HashAlg}; +pub use nsec3::{NSec3, NSec3Flags, NSec3HashAlg, NSec3Param}; mod ds; pub use ds::{DigestType, Ds}; diff --git a/src/new_rdata/dnssec/nsec.rs b/src/new_rdata/dnssec/nsec.rs index 682426ea7..6015354f4 100644 --- a/src/new_rdata/dnssec/nsec.rs +++ b/src/new_rdata/dnssec/nsec.rs @@ -11,7 +11,7 @@ use crate::new_base::{ //----------- NSec ----------------------------------------------------------- /// An indication of the non-existence of a set of DNS records (version 1). -#[derive(Clone, Debug, BuildBytes, ParseBytes)] +#[derive(Clone, Debug, PartialEq, Eq, BuildBytes, ParseBytes)] pub struct NSec<'a> { /// The name of the next existing DNS record. pub next: &'a Name, @@ -23,7 +23,7 @@ pub struct NSec<'a> { //----------- TypeBitmaps ---------------------------------------------------- /// A bitmap of DNS record types. -#[derive(AsBytes)] +#[derive(PartialEq, Eq, AsBytes, BuildBytes)] #[repr(transparent)] pub struct TypeBitmaps { octets: [u8], diff --git a/src/new_rdata/dnssec/nsec3.rs b/src/new_rdata/dnssec/nsec3.rs index 732766224..9981ded89 100644 --- a/src/new_rdata/dnssec/nsec3.rs +++ b/src/new_rdata/dnssec/nsec3.rs @@ -8,8 +8,8 @@ use super::TypeBitmaps; //----------- NSec3 ---------------------------------------------------------- -/// An indication of the non-existence of a DNS record (version 3). -#[derive(Clone, Debug, BuildBytes, ParseBytes)] +/// An indication of the non-existence of a set of DNS records (version 3). +#[derive(Clone, Debug, PartialEq, Eq, BuildBytes, ParseBytes)] pub struct NSec3<'a> { /// The algorithm used to hash names. pub algorithm: NSec3HashAlg, @@ -30,6 +30,33 @@ pub struct NSec3<'a> { pub types: &'a TypeBitmaps, } +//----------- NSec3Param ----------------------------------------------------- + +/// Parameters for computing [`NSec3`] records. +#[derive( + Debug, + PartialEq, + Eq, + AsBytes, + BuildBytes, + ParseBytesByRef, + SplitBytesByRef, +)] +#[repr(C)] +pub struct NSec3Param { + /// The algorithm used to hash names. + pub algorithm: NSec3HashAlg, + + /// Flags modifying the behaviour of the record. + pub flags: NSec3Flags, + + /// The number of iterations of the underlying hash function per name. + pub iterations: U16, + + /// The salt used to randomize the hash function. + pub salt: SizePrefixed, +} + //----------- NSec3HashAlg --------------------------------------------------- /// The hash algorithm used with [`NSec3`] records. @@ -80,6 +107,8 @@ impl fmt::Debug for NSec3HashAlg { Clone, Default, Hash, + PartialEq, + Eq, AsBytes, BuildBytes, ParseBytes, diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 807baf0fd..ed9680456 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -25,8 +25,8 @@ pub use edns::{EdnsOptionsIter, Opt}; mod dnssec; pub use dnssec::{ - DNSKey, DigestType, Ds, NSec, NSec3, NSec3Flags, NSec3HashAlg, RRSig, - SecAlg, + DNSKey, DNSKeyFlags, DigestType, Ds, NSec, NSec3, NSec3Flags, + NSec3HashAlg, NSec3Param, RRSig, SecAlg, }; //----------- RecordData ----------------------------------------------------- @@ -68,6 +68,24 @@ pub enum RecordData<'a, N> { /// Extended DNS options. Opt(&'a Opt), + /// The signing key of a delegated zone. + Ds(&'a Ds), + + /// A cryptographic signature on a DNS record set. + RRSig(RRSig<'a>), + + /// An indication of the non-existence of a set of DNS records (version 1). + NSec(NSec<'a>), + + /// A cryptographic key for DNS security. + DNSKey(&'a DNSKey), + + /// An indication of the non-existence of a set of DNS records (version 3). + NSec3(NSec3<'a>), + + /// Parameters for computing [`NSec3`] records. + NSec3Param(&'a NSec3Param), + /// Data for an unknown DNS record type. Unknown(RType, &'a UnknownRecordData), } @@ -89,6 +107,12 @@ impl RecordData<'_, N> { Self::Txt(..) => RType::TXT, Self::Aaaa(..) => RType::AAAA, Self::Opt(..) => RType::OPT, + Self::Ds(..) => RType::DS, + Self::RRSig(..) => RType::RRSIG, + Self::NSec(..) => RType::NSEC, + Self::DNSKey(..) => RType::DNSKEY, + Self::NSec3(..) => RType::NSEC3, + Self::NSec3Param(..) => RType::NSEC3PARAM, Self::Unknown(rtype, _) => *rtype, } } @@ -111,6 +135,12 @@ impl<'a, N> RecordData<'a, N> { Self::Txt(r) => RecordData::Txt(r), Self::Aaaa(r) => RecordData::Aaaa(r), Self::Opt(r) => RecordData::Opt(r), + Self::Ds(r) => RecordData::Ds(r), + Self::RRSig(r) => RecordData::RRSig(r), + Self::NSec(r) => RecordData::NSec(r), + Self::DNSKey(r) => RecordData::DNSKey(r), + Self::NSec3(r) => RecordData::NSec3(r), + Self::NSec3Param(r) => RecordData::NSec3Param(r), Self::Unknown(rt, rd) => RecordData::Unknown(rt, rd), } } @@ -132,6 +162,12 @@ impl<'a, N> RecordData<'a, N> { Self::Txt(r) => RecordData::Txt(r), Self::Aaaa(r) => RecordData::Aaaa(r), Self::Opt(r) => RecordData::Opt(r), + Self::Ds(r) => RecordData::Ds(r), + Self::RRSig(r) => RecordData::RRSig(r.clone()), + Self::NSec(r) => RecordData::NSec(r.clone()), + Self::DNSKey(r) => RecordData::DNSKey(r), + Self::NSec3(r) => RecordData::NSec3(r.clone()), + Self::NSec3Param(r) => RecordData::NSec3Param(r), Self::Unknown(rt, rd) => RecordData::Unknown(*rt, rd), } } @@ -180,6 +216,23 @@ where RType::OPT => { <&Opt>::parse_bytes(&contents[start..]).map(Self::Opt) } + RType::DS => <&Ds>::parse_bytes(&contents[start..]).map(Self::Ds), + RType::RRSIG => { + RRSig::parse_bytes(&contents[start..]).map(Self::RRSig) + } + RType::NSEC => { + NSec::parse_bytes(&contents[start..]).map(Self::NSec) + } + RType::DNSKEY => { + <&DNSKey>::parse_bytes(&contents[start..]).map(Self::DNSKey) + } + RType::NSEC3 => { + NSec3::parse_bytes(&contents[start..]).map(Self::NSec3) + } + RType::NSEC3PARAM => { + <&NSec3Param>::parse_bytes(&contents[start..]) + .map(Self::NSec3Param) + } _ => <&UnknownRecordData>::parse_bytes(&contents[start..]) .map(|data| Self::Unknown(rtype, data)), } @@ -201,6 +254,14 @@ where RType::TXT => <&Txt>::parse_bytes(bytes).map(Self::Txt), RType::AAAA => <&Aaaa>::parse_bytes(bytes).map(Self::Aaaa), RType::OPT => <&Opt>::parse_bytes(bytes).map(Self::Opt), + RType::DS => <&Ds>::parse_bytes(bytes).map(Self::Ds), + RType::RRSIG => RRSig::parse_bytes(bytes).map(Self::RRSig), + RType::NSEC => NSec::parse_bytes(bytes).map(Self::NSec), + RType::DNSKEY => <&DNSKey>::parse_bytes(bytes).map(Self::DNSKey), + RType::NSEC3 => NSec3::parse_bytes(bytes).map(Self::NSec3), + RType::NSEC3PARAM => { + <&NSec3Param>::parse_bytes(bytes).map(Self::NSec3Param) + } _ => <&UnknownRecordData>::parse_bytes(bytes) .map(|data| Self::Unknown(rtype, data)), } @@ -226,6 +287,12 @@ impl BuildIntoMessage for RecordData<'_, N> { Self::Txt(r) => builder.append_bytes(r.as_bytes())?, Self::Aaaa(r) => builder.append_bytes(r.as_bytes())?, Self::Opt(r) => builder.append_bytes(r.as_bytes())?, + Self::Ds(r) => builder.append_bytes(r.as_bytes())?, + Self::RRSig(r) => builder.append_built_bytes(r)?, + Self::NSec(r) => builder.append_built_bytes(r)?, + Self::DNSKey(r) => builder.append_bytes(r.as_bytes())?, + Self::NSec3(r) => builder.append_built_bytes(r)?, + Self::NSec3Param(r) => builder.append_bytes(r.as_bytes())?, Self::Unknown(_, r) => builder.append_bytes(r.as_bytes())?, } @@ -250,6 +317,12 @@ impl BuildBytes for RecordData<'_, N> { Self::Txt(r) => r.build_bytes(bytes), Self::Aaaa(r) => r.build_bytes(bytes), Self::Opt(r) => r.build_bytes(bytes), + Self::Ds(r) => r.build_bytes(bytes), + Self::RRSig(r) => r.build_bytes(bytes), + Self::NSec(r) => r.build_bytes(bytes), + Self::DNSKey(r) => r.build_bytes(bytes), + Self::NSec3(r) => r.build_bytes(bytes), + Self::NSec3Param(r) => r.build_bytes(bytes), Self::Unknown(_, r) => r.build_bytes(bytes), } } From 09d79b79db2974e1a14855f3a2f370f4e4f874e6 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 10 Mar 2025 16:54:40 +0100 Subject: [PATCH 148/167] [new_base/name/absolute] Add 'NameBuf' This will be used for zonefile parsing. --- src/new_base/name/absolute.rs | 289 +++++++++++++++++++++++++++++++++- src/new_base/name/mod.rs | 2 +- src/new_base/name/reversed.rs | 2 +- 3 files changed, 286 insertions(+), 7 deletions(-) diff --git a/src/new_base/name/absolute.rs b/src/new_base/name/absolute.rs index 5c0391b71..721bccef9 100644 --- a/src/new_base/name/absolute.rs +++ b/src/new_base/name/absolute.rs @@ -1,13 +1,18 @@ //! Absolute domain names. use core::{ + borrow::{Borrow, BorrowMut}, fmt, hash::{Hash, Hasher}, + ops::{Deref, DerefMut}, }; -use domain_macros::{AsBytes, BuildBytes}; +use domain_macros::*; -use crate::new_base::wire::{ParseBytes, ParseError, SplitBytes}; +use crate::new_base::{ + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, +}; use super::LabelIter; @@ -184,8 +189,282 @@ impl fmt::Display for Name { impl fmt::Debug for Name { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Name") - .field(&format_args!("{}", self)) - .finish() + write!(f, "Name({})", self) + } +} + +//----------- NameBuf -------------------------------------------------------- + +/// A 256-byte buffer containing a [`Name`]. +#[derive(Clone)] +#[repr(C)] // make layout compatible with '[u8; 256]' +pub struct NameBuf { + /// The size of the encoded name. + size: u8, + + /// The buffer containing the [`Name`]. + buffer: [u8; 255], +} + +//--- Construction + +impl NameBuf { + /// Construct an empty, invalid buffer. + const fn empty() -> Self { + Self { + size: 0, + buffer: [0; 255], + } + } + + /// Copy a [`Name`] into a buffer. + pub fn copy_from(name: &Name) -> Self { + let mut buffer = [0u8; 255]; + buffer[..name.len()].copy_from_slice(name.as_bytes()); + Self { + size: name.len() as u8, + buffer, + } + } +} + +//--- Parsing from DNS messages + +impl<'a> SplitMessageBytes<'a> for NameBuf { + fn split_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result<(Self, usize), ParseError> { + // NOTE: The input may be controlled by an attacker. Compression + // pointers can be arranged to cause loops or to access every byte in + // the message in random order. Instead of performing complex loop + // detection, which would probably perform allocations, we simply + // disallow a name to point to data _after_ it. Standard name + // compressors will never generate such pointers. + + let mut buffer = Self::empty(); + + // Perform the first iteration early, to catch the end of the name. + let bytes = contents.get(start..).ok_or(ParseError)?; + let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; + let orig_end = contents.len() - rest.len(); + + // Traverse compression pointers. + let mut old_start = start; + while let Some(start) = pointer.map(usize::from) { + // Ensure the referenced position comes earlier. + if start >= old_start { + return Err(ParseError); + } + + // Keep going, from the referenced position. + let start = start.checked_sub(12).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; + (pointer, _) = parse_segment(bytes, &mut buffer)?; + old_start = start; + continue; + } + + // Stop and return the original end. + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been appended into it). + Ok((buffer, orig_end)) + } +} + +impl<'a> ParseMessageBytes<'a> for NameBuf { + fn parse_message_bytes( + contents: &'a [u8], + start: usize, + ) -> Result { + // See 'split_from_message()' for details. The only differences are + // in the range of the first iteration, and the check that the first + // iteration exactly covers the input range. + + let mut buffer = Self::empty(); + + // Perform the first iteration early, to catch the end of the name. + let bytes = contents.get(start..).ok_or(ParseError)?; + let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?; + + if !rest.is_empty() { + // The name didn't reach the end of the input range, fail. + return Err(ParseError); + } + + // Traverse compression pointers. + let mut old_start = start; + while let Some(start) = pointer.map(usize::from) { + // Ensure the referenced position comes earlier. + if start >= old_start { + return Err(ParseError); + } + + // Keep going, from the referenced position. + let start = start.checked_sub(12).ok_or(ParseError)?; + let bytes = contents.get(start..).ok_or(ParseError)?; + (pointer, _) = parse_segment(bytes, &mut buffer)?; + old_start = start; + continue; + } + + // NOTE: 'buffer' is now well-formed because we only stop when we + // reach a root label (which has been appended into it). + Ok(buffer) + } +} + +/// Parse an encoded and potentially-compressed domain name, without +/// following any compression pointer. +fn parse_segment<'a>( + mut bytes: &'a [u8], + buffer: &mut NameBuf, +) -> Result<(Option, &'a [u8]), ParseError> { + loop { + match *bytes { + [0, ref rest @ ..] => { + // Found the root, stop. + buffer.append_bytes(&[0u8]); + return Ok((None, rest)); + } + + [l, ..] if l < 64 => { + // This looks like a regular label. + + if bytes.len() < 1 + l as usize { + // The input doesn't contain the whole label. + return Err(ParseError); + } else if 255 - buffer.size < 2 + l { + // The output name would exceed 254 bytes (this isn't + // the root label, so it can't fill the 255th byte). + return Err(ParseError); + } + + let (label, rest) = bytes.split_at(1 + l as usize); + buffer.append_bytes(label); + bytes = rest; + } + + [hi, lo, ref rest @ ..] if hi >= 0xC0 => { + let pointer = u16::from_be_bytes([hi, lo]); + + // NOTE: We don't verify the pointer here, that's left to + // the caller (since they have to actually use it). + return Ok((Some(pointer & 0x3FFF), rest)); + } + + _ => return Err(ParseError), + } + } +} + +//--- Parsing from bytes + +impl<'a> SplitBytes<'a> for NameBuf { + fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> { + <&Name>::split_bytes(bytes) + .map(|(name, rest)| (NameBuf::copy_from(name), rest)) + } +} + +impl<'a> ParseBytes<'a> for NameBuf { + fn parse_bytes(bytes: &'a [u8]) -> Result { + <&Name>::parse_bytes(bytes).map(NameBuf::copy_from) + } +} + +//--- Building into byte strings + +impl BuildBytes for NameBuf { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + (**self).build_bytes(bytes) + } +} + +//--- Interaction + +impl NameBuf { + /// Append bytes to this buffer. + /// + /// This is an internal convenience function used while building buffers. + fn append_bytes(&mut self, bytes: &[u8]) { + self.buffer[self.size as usize..][..bytes.len()] + .copy_from_slice(bytes); + self.size += bytes.len() as u8; + } +} + +//--- Access to the underlying 'Name' + +impl Deref for NameBuf { + type Target = Name; + + fn deref(&self) -> &Self::Target { + let name = &self.buffer[..self.size as usize]; + // SAFETY: A 'NameBuf' always contains a valid 'Name'. + unsafe { Name::from_bytes_unchecked(name) } + } +} + +impl DerefMut for NameBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + let name = &mut self.buffer[..self.size as usize]; + // SAFETY: A 'NameBuf' always contains a valid 'Name'. + unsafe { Name::from_bytes_unchecked_mut(name) } + } +} + +impl Borrow for NameBuf { + fn borrow(&self) -> &Name { + self + } +} + +impl BorrowMut for NameBuf { + fn borrow_mut(&mut self) -> &mut Name { + self + } +} + +impl AsRef for NameBuf { + fn as_ref(&self) -> &Name { + self + } +} + +impl AsMut for NameBuf { + fn as_mut(&mut self) -> &mut Name { + self + } +} + +//--- Forwarding equality, hashing, and formatting + +impl PartialEq for NameBuf { + fn eq(&self, that: &Self) -> bool { + **self == **that + } +} + +impl Eq for NameBuf {} + +impl Hash for NameBuf { + fn hash(&self, state: &mut H) { + (**self).hash(state) + } +} + +impl fmt::Display for NameBuf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl fmt::Debug for NameBuf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) } } diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs index c74aeab3f..c1cc2df2d 100644 --- a/src/new_base/name/mod.rs +++ b/src/new_base/name/mod.rs @@ -18,7 +18,7 @@ mod label; pub use label::{Label, LabelBuf, LabelIter}; mod absolute; -pub use absolute::Name; +pub use absolute::{Name, NameBuf}; mod reversed; pub use reversed::{RevName, RevNameBuf}; diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 2edf9a75b..188660154 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -229,7 +229,7 @@ pub struct RevNameBuf { impl RevNameBuf { /// Construct an empty, invalid buffer. - fn empty() -> Self { + const fn empty() -> Self { Self { offset: 255, buffer: [0; 255], From 3c630489ef90398abdc8b7c2a2eea4f9807b3797 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 12 Mar 2025 13:27:56 +0100 Subject: [PATCH 149/167] [new_base] Add 'bumpalo' and impl 'clone_to_bump()' 'bumpalo' is useful for allocating unsized data types like record data. It is currently used in the 'new-zonefile' branch as a buffer space for records to be parsed into. 'clone_to_bump()' is necessary in order to deep-copy record data into a new bump allocator. --- Cargo.lock | 1 + Cargo.toml | 2 ++ src/new_base/charstr.rs | 13 ++++++++ src/new_base/name/absolute.rs | 13 ++++++++ src/new_base/name/reversed.rs | 13 ++++++++ src/new_base/wire/size_prefixed.rs | 21 ++++++++++-- src/new_rdata/basic/hinfo.rs | 13 ++++++++ src/new_rdata/basic/txt.rs | 52 +++++++++++++++++++++--------- src/new_rdata/basic/wks.rs | 15 +++++++++ src/new_rdata/dnssec/dnskey.rs | 13 ++++++++ src/new_rdata/dnssec/ds.rs | 15 +++++++++ src/new_rdata/dnssec/nsec.rs | 28 ++++++++++++++++ src/new_rdata/dnssec/nsec3.rs | 32 ++++++++++++++++++ src/new_rdata/dnssec/rrsig.rs | 14 ++++++++ src/new_rdata/edns.rs | 15 +++++++++ src/new_rdata/mod.rs | 50 ++++++++++++++++++++++++++++ 16 files changed, 291 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a0eb76973..d9c2d9782 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -232,6 +232,7 @@ version = "0.10.3" dependencies = [ "arbitrary", "arc-swap", + "bumpalo", "bytes", "chrono", "domain-macros", diff --git a/Cargo.toml b/Cargo.toml index a9321cd4c..411cb7a42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ license = "BSD-3-Clause" domain-macros = { path = "./macros", version = "0.10.3" } arbitrary = { version = "1.4.1", optional = true, features = ["derive"] } +bumpalo = { version = "3.12", optional = true } octseq = { version = "0.5.2", default-features = false } time = { version = "0.3.1", default-features = false } rand = { version = "0.8", optional = true } @@ -54,6 +55,7 @@ tracing-subscriber = { version = "0.3.18", optional = true, features = ["env-fil default = ["std", "rand"] # Support for libraries +bumpalo = ["dep:bumpalo"] bytes = ["dep:bytes", "octseq/bytes"] heapless = ["dep:heapless", "octseq/heapless"] serde = ["dep:serde", "octseq/serde"] diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index adf8e6024..3fd617e90 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -17,6 +17,19 @@ pub struct CharStr { pub octets: [u8], } +//--- Interaction + +impl CharStr { + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + let octets = bump.alloc_slice_copy(&self.octets); + // SAFETY: 'CharStr' is 'repr(transparent)' to '[u8]'. + unsafe { core::mem::transmute::<&mut [u8], &mut CharStr>(octets) } + } +} + //--- Parsing from DNS messages impl<'a> SplitMessageBytes<'a> for &'a CharStr { diff --git a/src/new_base/name/absolute.rs b/src/new_base/name/absolute.rs index 721bccef9..c06f41317 100644 --- a/src/new_base/name/absolute.rs +++ b/src/new_base/name/absolute.rs @@ -95,6 +95,19 @@ impl Name { } } +//--- Interaction + +impl Name { + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + let bytes = bump.alloc_slice_copy(self.as_bytes()); + // SAFETY: 'AsBytes' is a transmute, so we can transmute back. + unsafe { core::mem::transmute::<&mut [u8], &mut Self>(bytes) } + } +} + //--- Parsing from bytes impl<'a> ParseBytes<'a> for &'a Name { diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 188660154..e9f7c0318 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -105,6 +105,19 @@ impl RevName { } } +//--- Interaction + +impl RevName { + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + let octets = bump.alloc_slice_copy(self.as_bytes()); + // SAFETY: 'RevName' is 'repr(transparent)' to '[u8]'. + unsafe { core::mem::transmute::<&mut [u8], &mut Self>(octets) } + } +} + //--- Building into DNS messages impl BuildIntoMessage for RevName { diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index 42c11e295..b9583cf2c 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -25,7 +25,7 @@ use super::{ /// options). When serialized as bytes, the inner value is prefixed with an /// integer (often a [`U16`](super::U16)) indicating the length of the inner /// value in bytes. -#[derive(Copy, Clone)] +#[derive(Copy, Clone, AsBytes)] #[repr(C)] pub struct SizePrefixed { /// The size prefix (needed for 'ParseBytesByRef' / 'AsBytes'). @@ -61,6 +61,23 @@ where } } +//--- Interaction + +impl SizePrefixed +where + S: AsBytes + SplitBytesByRef + Copy + TryFrom + TryInto, + T: AsBytes + ParseBytesByRef, +{ + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + let bytes = bump.alloc_slice_copy(self.as_bytes()); + // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. + unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } + } +} + //--- Conversion from the inner data impl From for SizePrefixed @@ -359,5 +376,3 @@ where Ok(rest) } } - -unsafe impl AsBytes for SizePrefixed {} diff --git a/src/new_rdata/basic/hinfo.rs b/src/new_rdata/basic/hinfo.rs index 0151c97d3..ea6e4ad93 100644 --- a/src/new_rdata/basic/hinfo.rs +++ b/src/new_rdata/basic/hinfo.rs @@ -13,3 +13,16 @@ pub struct HInfo<'a> { /// The OS type. pub os: &'a CharStr, } + +//--- Interaction + +impl HInfo<'_> { + /// Copy referenced data into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> HInfo<'r> { + HInfo { + cpu: self.cpu.clone_to_bump(bump), + os: self.os.clone_to_bump(bump), + } + } +} diff --git a/src/new_rdata/basic/txt.rs b/src/new_rdata/basic/txt.rs index a7188e9b0..ff8830a0d 100644 --- a/src/new_rdata/basic/txt.rs +++ b/src/new_rdata/basic/txt.rs @@ -4,8 +4,7 @@ use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, - parse::ParseMessageBytes, - wire::{ParseBytes, ParseError, SplitBytes}, + wire::{ParseBytesByRef, ParseError, SplitBytes}, CharStr, }; @@ -24,6 +23,17 @@ pub struct Txt { //--- Interaction impl Txt { + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + use crate::new_base::wire::AsBytes; + + let bytes = bump.alloc_slice_copy(self.as_bytes()); + // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. + unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } + } + /// Iterate over the [`CharStr`]s in this record. pub fn iter(&self) -> impl Iterator + '_ { // NOTE: A TXT record always has at least one 'CharStr' within. @@ -39,17 +49,6 @@ impl Txt { } } -//--- Parsing from DNS messages - -impl<'a> ParseMessageBytes<'a> for &'a Txt { - fn parse_message_bytes( - contents: &'a [u8], - start: usize, - ) -> Result { - Self::parse_bytes(&contents[start..]) - } -} - //--- Building into DNS messages impl BuildIntoMessage for Txt { @@ -60,16 +59,37 @@ impl BuildIntoMessage for Txt { //--- Parsing from bytes -impl<'a> ParseBytes<'a> for &'a Txt { - fn parse_bytes(bytes: &'a [u8]) -> Result { +impl Txt { + /// Validate the given bytes as a 'Txt'. + fn validate_bytes(bytes: &[u8]) -> Result<(), ParseError> { // NOTE: The input must contain at least one 'CharStr'. let (_, mut rest) = <&CharStr>::split_bytes(bytes)?; while !rest.is_empty() { (_, rest) = <&CharStr>::split_bytes(rest)?; } + Ok(()) + } +} + +// SAFETY: The implementations of 'parse_bytes_by_{ref,mut}()' always parse +// the entirety of the input on success, satisfying the safety requirements. +unsafe impl ParseBytesByRef for Txt { + fn parse_bytes_by_ref(bytes: &[u8]) -> Result<&Self, ParseError> { + Self::validate_bytes(bytes)?; + + // SAFETY: 'Txt' is 'repr(transparent)' to '[u8]'. + Ok(unsafe { core::mem::transmute::<&[u8], &Self>(bytes) }) + } + + fn parse_bytes_by_mut(bytes: &mut [u8]) -> Result<&mut Self, ParseError> { + Self::validate_bytes(bytes)?; // SAFETY: 'Txt' is 'repr(transparent)' to '[u8]'. - Ok(unsafe { core::mem::transmute::<&'a [u8], Self>(bytes) }) + Ok(unsafe { core::mem::transmute::<&mut [u8], &mut Self>(bytes) }) + } + + fn ptr_with_address(&self, addr: *const ()) -> *const Self { + self.content.ptr_with_address(addr) as *const Self } } diff --git a/src/new_rdata/basic/wks.rs b/src/new_rdata/basic/wks.rs index 9e29d4133..2e7418360 100644 --- a/src/new_rdata/basic/wks.rs +++ b/src/new_rdata/basic/wks.rs @@ -22,6 +22,21 @@ pub struct Wks { pub ports: [u8], } +//--- Interaction + +impl Wks { + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + use crate::new_base::wire::ParseBytesByRef; + + let bytes = bump.alloc_slice_copy(self.as_bytes()); + // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. + unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } + } +} + //--- Formatting impl fmt::Debug for Wks { diff --git a/src/new_rdata/dnssec/dnskey.rs b/src/new_rdata/dnssec/dnskey.rs index ac77c4c57..0e9c391a9 100644 --- a/src/new_rdata/dnssec/dnskey.rs +++ b/src/new_rdata/dnssec/dnskey.rs @@ -25,6 +25,19 @@ pub struct DNSKey { pub key: [u8], } +impl DNSKey { + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + use crate::new_base::wire::{AsBytes, ParseBytesByRef}; + + let bytes = bump.alloc_slice_copy(self.as_bytes()); + // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. + unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } + } +} + //----------- DNSKeyFlags ---------------------------------------------------- /// Flags describing a [`DNSKey`]. diff --git a/src/new_rdata/dnssec/ds.rs b/src/new_rdata/dnssec/ds.rs index 369fc355d..ed92dd1f4 100644 --- a/src/new_rdata/dnssec/ds.rs +++ b/src/new_rdata/dnssec/ds.rs @@ -25,6 +25,21 @@ pub struct Ds { pub digest: [u8], } +//--- Interaction + +impl Ds { + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + use crate::new_base::wire::{AsBytes, ParseBytesByRef}; + + let bytes = bump.alloc_slice_copy(self.as_bytes()); + // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. + unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } + } +} + //----------- DigestType ----------------------------------------------------- /// A cryptographic digest algorithm. diff --git a/src/new_rdata/dnssec/nsec.rs b/src/new_rdata/dnssec/nsec.rs index 6015354f4..b7b47191d 100644 --- a/src/new_rdata/dnssec/nsec.rs +++ b/src/new_rdata/dnssec/nsec.rs @@ -20,6 +20,19 @@ pub struct NSec<'a> { pub types: &'a TypeBitmaps, } +//--- Interaction + +impl NSec<'_> { + /// Copy referenced data into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> NSec<'r> { + NSec { + next: self.next.clone_to_bump(bump), + types: self.types.clone_to_bump(bump), + } + } +} + //----------- TypeBitmaps ---------------------------------------------------- /// A bitmap of DNS record types. @@ -56,6 +69,21 @@ impl TypeBitmaps { } } +//--- Interaction + +impl TypeBitmaps { + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + use crate::new_base::wire::{AsBytes, ParseBytesByRef}; + + let bytes = bump.alloc_slice_copy(self.as_bytes()); + // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. + unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } + } +} + //--- Formatting impl fmt::Debug for TypeBitmaps { diff --git a/src/new_rdata/dnssec/nsec3.rs b/src/new_rdata/dnssec/nsec3.rs index 9981ded89..f8b03b049 100644 --- a/src/new_rdata/dnssec/nsec3.rs +++ b/src/new_rdata/dnssec/nsec3.rs @@ -30,6 +30,23 @@ pub struct NSec3<'a> { pub types: &'a TypeBitmaps, } +//--- Interaction + +impl NSec3<'_> { + /// Copy referenced data into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> NSec3<'r> { + NSec3 { + algorithm: self.algorithm, + flags: self.flags, + iterations: self.iterations, + salt: self.salt.clone_to_bump(bump), + next: self.next.clone_to_bump(bump), + types: self.types.clone_to_bump(bump), + } + } +} + //----------- NSec3Param ----------------------------------------------------- /// Parameters for computing [`NSec3`] records. @@ -57,6 +74,21 @@ pub struct NSec3Param { pub salt: SizePrefixed, } +//--- Interaction + +impl NSec3Param { + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + use crate::new_base::wire::{AsBytes, ParseBytesByRef}; + + let bytes = bump.alloc_slice_copy(self.as_bytes()); + // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. + unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } + } +} + //----------- NSec3HashAlg --------------------------------------------------- /// The hash algorithm used with [`NSec3`] records. diff --git a/src/new_rdata/dnssec/rrsig.rs b/src/new_rdata/dnssec/rrsig.rs index ea9a6f5c6..60f9fddb3 100644 --- a/src/new_rdata/dnssec/rrsig.rs +++ b/src/new_rdata/dnssec/rrsig.rs @@ -36,3 +36,17 @@ pub struct RRSig<'a> { /// The serialized cryptographic signature. pub signature: &'a [u8], } + +//--- Interaction + +impl RRSig<'_> { + /// Copy referenced data into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> RRSig<'r> { + RRSig { + signer: self.signer.clone_to_bump(bump), + signature: bump.alloc_slice_copy(self.signature), + ..self.clone() + } + } +} diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 46bba4ed3..d03b739a4 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -32,6 +32,21 @@ impl Opt { } } +//--- Interaction + +impl Opt { + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + use crate::new_base::wire::{AsBytes, ParseBytesByRef}; + + let bytes = bump.alloc_slice_copy(self.as_bytes()); + // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. + unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } + } +} + //--- Formatting impl fmt::Debug for Opt { diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index ed9680456..540cbed6e 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -171,6 +171,41 @@ impl<'a, N> RecordData<'a, N> { Self::Unknown(rt, rd) => RecordData::Unknown(*rt, rd), } } + + /// Copy referenced data into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + pub fn clone_to_bump<'r>( + &self, + bump: &'r bumpalo::Bump, + ) -> RecordData<'r, N> + where + N: Clone, + { + match self { + Self::A(&r) => RecordData::A(bump.alloc(r)), + Self::Ns(r) => RecordData::Ns(r.clone()), + Self::CName(r) => RecordData::CName(r.clone()), + Self::Soa(r) => RecordData::Soa(r.clone()), + Self::Wks(r) => RecordData::Wks(r.clone_to_bump(bump)), + Self::Ptr(r) => RecordData::Ptr(r.clone()), + Self::HInfo(r) => RecordData::HInfo(r.clone_to_bump(bump)), + Self::Mx(r) => RecordData::Mx(r.clone()), + Self::Txt(r) => RecordData::Txt(r.clone_to_bump(bump)), + Self::Aaaa(&r) => RecordData::Aaaa(bump.alloc(r)), + Self::Opt(r) => RecordData::Opt(r.clone_to_bump(bump)), + Self::Ds(r) => RecordData::Ds(r.clone_to_bump(bump)), + Self::RRSig(r) => RecordData::RRSig(r.clone_to_bump(bump)), + Self::NSec(r) => RecordData::NSec(r.clone_to_bump(bump)), + Self::DNSKey(r) => RecordData::DNSKey(r.clone_to_bump(bump)), + Self::NSec3(r) => RecordData::NSec3(r.clone_to_bump(bump)), + Self::NSec3Param(r) => { + RecordData::NSec3Param(r.clone_to_bump(bump)) + } + Self::Unknown(rt, rd) => { + RecordData::Unknown(*rt, rd.clone_to_bump(bump)) + } + } + } } //--- Parsing record data @@ -337,3 +372,18 @@ pub struct UnknownRecordData { /// The unparsed option data. pub octets: [u8], } + +//--- Interaction + +impl UnknownRecordData { + /// Copy this into the given [`Bump`] allocator. + #[cfg(feature = "bumpalo")] + #[allow(clippy::mut_from_ref)] // using a memory allocator + pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { + use crate::new_base::wire::{AsBytes, ParseBytesByRef}; + + let bytes = bump.alloc_slice_copy(self.as_bytes()); + // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. + unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } + } +} From 55815acb04f0d447a5d4c2453a45efbb7b4326fd Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 17 Mar 2025 17:53:21 +0100 Subject: [PATCH 150/167] [utils] Add 'UnsizedClone' and 'CloneFrom' These are important for improving DST ergonomics; they make it much easier to copy around unsized types, which 'domain' now has many of. --- src/utils/mod.rs | 245 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 245 insertions(+) diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 584724715..262ad3e58 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,8 +1,253 @@ //! Various utility modules. +use core::{ops::Deref, ptr::addr_of_mut}; + +#[cfg(feature = "std")] +use std::{alloc::Layout, boxed::Box, vec::Vec}; + pub mod base16; pub mod base32; pub mod base64; #[cfg(feature = "net")] pub(crate) mod config; + +//----------- UnsizedClone --------------------------------------------------- + +/// The ability to clone a (possibly unsized) value. +/// +/// # Safety +/// +/// If `unsized_clone()` returns successfully (i.e. without panicking), `dst` +/// is initialized to a valid instance of `Self`. +pub unsafe trait UnsizedClone { + /// Clone this value into the given space. + /// + /// # Safety + /// + /// `dst` must be allocated with the same size and alignment as `self`. + unsafe fn unsized_clone(&self, dst: *mut ()); + + /// Change the address of a pointer to [`Self`]. + /// + /// When [`Self`] is used as the last field in a type that also implements + /// [`ParseBytesByRef`], it may be dynamically sized, and so a pointer (or + /// reference) to it may include additional metadata. This metadata is + /// included verbatim in any reference/pointer to the containing type. + /// + /// When the containing type implements [`UnsizedClone`], it needs to + /// construct a reference/pointer to itself, which includes this metadata. + /// Rust does not currently offer a general way to extract this metadata + /// or pair it with another address, so this function is necessary. The + /// caller can construct a reference to [`Self`], then change its address + /// to point to the containing type, then cast that pointer to the right + /// type. + /// + /// # Implementing + /// + /// Most users will derive [`UnsizedClone`] and so don't need to worry + /// about this. For manual implementations: + /// + /// In the future, an adequate default implementation for this function + /// may be provided. Until then, it should be implemented using one of + /// the following expressions: + /// + /// ```text + /// fn ptr_with_address( + /// &self, + /// addr: *mut (), + /// ) -> *const Self { + /// // If 'Self' is Sized: + /// addr.cast::() + /// + /// // If 'Self' is an aggregate whose last field is 'last': + /// self.last.ptr_with_address(addr) as *mut Self + /// } + /// ``` + /// + /// # Invariants + /// + /// For the statement `let result = Self::ptr_with_address(ptr, addr);`: + /// + /// - `result as usize == addr as usize`. + /// - `core::ptr::metadata(result) == core::ptr::metadata(ptr)`. + fn ptr_with_address(&self, addr: *mut ()) -> *mut Self; +} + +macro_rules! impl_primitive_unsized_clone { + ($type:ty) => { + unsafe impl UnsizedClone for $type { + unsafe fn unsized_clone(&self, dst: *mut ()) { + let this = self.clone(); + unsafe { dst.cast::().write(this) }; + } + + fn ptr_with_address(&self, addr: *mut ()) -> *mut Self { + addr.cast::() + } + } + }; +} + +impl_primitive_unsized_clone!(bool); +impl_primitive_unsized_clone!(char); + +impl_primitive_unsized_clone!(u8); +impl_primitive_unsized_clone!(u16); +impl_primitive_unsized_clone!(u32); +impl_primitive_unsized_clone!(u64); +impl_primitive_unsized_clone!(u128); +impl_primitive_unsized_clone!(usize); + +impl_primitive_unsized_clone!(i8); +impl_primitive_unsized_clone!(i16); +impl_primitive_unsized_clone!(i32); +impl_primitive_unsized_clone!(i64); +impl_primitive_unsized_clone!(i128); +impl_primitive_unsized_clone!(isize); + +impl_primitive_unsized_clone!(f32); +impl_primitive_unsized_clone!(f64); + +unsafe impl UnsizedClone for &T { + unsafe fn unsized_clone(&self, dst: *mut ()) { + unsafe { dst.cast::().write(*self) }; + } + + fn ptr_with_address(&self, addr: *mut ()) -> *mut Self { + addr.cast() + } +} + +unsafe impl UnsizedClone for str { + unsafe fn unsized_clone(&self, dst: *mut ()) { + unsafe { + self.as_bytes() + .as_ptr() + .copy_to_nonoverlapping(dst.cast(), self.len()); + } + } + + fn ptr_with_address(&self, addr: *mut ()) -> *mut Self { + // NOTE: The Rust Reference indicates that 'str' has the same layout + // as '[u8]' [1]. This is also the most natural layout for it. Since + // there's no way to construct a '*const str' from raw parts, we will + // just construct a raw slice and transmute it. + // + // [1]: https://doc.rust-lang.org/reference/type-layout.html#str-layout + + self.as_bytes().ptr_with_address(addr) as *mut Self + } +} + +unsafe impl UnsizedClone for [T] { + unsafe fn unsized_clone(&self, dst: *mut ()) { + /// A drop guard. + struct Guard { + /// The slice being written into. + dst: *mut T, + + /// The number of initialized elements. + num: usize, + } + + impl Drop for Guard { + fn drop(&mut self) { + // Drop all initialized elements. + unsafe { + core::ptr::slice_from_raw_parts_mut(self.dst, self.num) + .drop_in_place() + }; + } + } + + let mut guard = Guard { + dst: dst.cast::(), + num: 0, + }; + + for elem in self { + let elem = elem.clone(); + unsafe { + guard.dst.add(guard.num).write(elem); + } + guard.num += 1; + } + } + + fn ptr_with_address(&self, addr: *mut ()) -> *mut Self { + core::ptr::slice_from_raw_parts_mut(addr.cast(), self.len()) + } +} + +unsafe impl UnsizedClone for [T; N] { + unsafe fn unsized_clone(&self, dst: *mut ()) { + let this = self.clone(); + unsafe { dst.cast::().write(this) }; + } + + fn ptr_with_address(&self, addr: *mut ()) -> *mut Self { + addr as *mut Self + } +} + +impl_primitive_unsized_clone!(()); + +macro_rules! impl_unsized_clone_tuple { + ($last_idx:tt: $last_type:ident $(, $idx:tt: $type:ident)*) => { + unsafe impl + <$($type: Clone,)* $last_type: ?Sized + UnsizedClone> + UnsizedClone for ($($type,)* $last_type,) { + unsafe fn unsized_clone(&self, dst: *mut ()) { + let dst: *mut Self = self.ptr_with_address(dst); + unsafe { + $(addr_of_mut!((*dst).$idx).write(self.$idx.clone());)* + self.$last_idx.unsized_clone(addr_of_mut!((*dst).$last_idx).cast()); + } + } + + fn ptr_with_address(&self, addr: *mut ()) -> *mut Self { + self.$last_idx.ptr_with_address(addr) as *mut Self + } + } + }; +} + +impl_unsized_clone_tuple!(0: A); +impl_unsized_clone_tuple!(1: B, 0: A); +impl_unsized_clone_tuple!(2: C, 0: A, 1: B); +impl_unsized_clone_tuple!(3: D, 0: A, 1: B, 2: C); +impl_unsized_clone_tuple!(4: E, 0: A, 1: B, 2: C, 3: D); +impl_unsized_clone_tuple!(5: F, 0: A, 1: B, 2: C, 3: D, 4: E); +impl_unsized_clone_tuple!(6: G, 0: A, 1: B, 2: C, 3: D, 4: E, 5: F); +impl_unsized_clone_tuple!(7: H, 0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G); +impl_unsized_clone_tuple!(8: I, 0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H); +impl_unsized_clone_tuple!(9: J, 0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I); +impl_unsized_clone_tuple!(10: K, 0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J); +impl_unsized_clone_tuple!(11: L, 0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K); + +//----------- CloneFrom ------------------------------------------------------ + +/// A container that can be built by cloning. +pub trait CloneFrom: Sized + Deref { + /// Clone a value into this container. + fn clone_from(value: &Self::Target) -> Self; +} + +#[cfg(feature = "std")] +impl CloneFrom for Box { + fn clone_from(value: &Self::Target) -> Self { + let layout = Layout::for_value(value); + let ptr = unsafe { std::alloc::alloc(layout) }; + unsafe { value.unsized_clone(ptr.cast()) }; + let ptr = value.ptr_with_address(ptr.cast()); + unsafe { Box::from_raw(ptr) } + } +} + +#[cfg(feature = "std")] +impl CloneFrom for Vec { + fn clone_from(value: &Self::Target) -> Self { + value.into() + } +} From d14d365abf695e0227de38d6a97e060d36356227 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Mon, 17 Mar 2025 18:50:30 +0100 Subject: [PATCH 151/167] Define and use a derive macro for 'UnsizedClone' --- macros/src/lib.rs | 106 +++++++++++++++++++++++++++++ src/new_base/charstr.rs | 3 + src/new_base/name/absolute.rs | 19 ++++-- src/new_base/name/label.rs | 25 +++++-- src/new_base/name/reversed.rs | 22 ++++-- src/new_base/wire/size_prefixed.rs | 4 +- src/new_rdata/basic/cname.rs | 1 + src/new_rdata/basic/mx.rs | 1 + src/new_rdata/basic/ns.rs | 1 + src/new_rdata/basic/ptr.rs | 1 + src/new_rdata/basic/txt.rs | 2 +- src/new_rdata/basic/wks.rs | 2 +- src/new_rdata/dnssec/dnskey.rs | 4 +- src/new_rdata/dnssec/mod.rs | 1 + src/new_rdata/dnssec/nsec.rs | 2 +- src/new_rdata/dnssec/nsec3.rs | 1 + src/new_rdata/edns.rs | 10 ++- 17 files changed, 184 insertions(+), 21 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index a23e6902a..74ce48337 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -610,6 +610,112 @@ pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream { .into() } +//----------- UnsizedClone --------------------------------------------------- + +#[proc_macro_derive(UnsizedClone)] +pub fn derive_unsized_clone(input: pm::TokenStream) -> pm::TokenStream { + fn inner(input: syn::DeriveInput) -> Result { + // Construct an 'ImplSkeleton' so that we can add trait bounds. + let mut skeleton = ImplSkeleton::new(&input, true); + skeleton.bound = + Some(syn::parse_quote!(::domain::utils::UnsizedClone)); + + let struct_data = match &input.data { + syn::Data::Struct(data) if !data.fields.is_empty() => { + let data = Struct::new_as_self(&data.fields); + for field in data.sized_fields() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::__core::clone::Clone), + ); + } + + skeleton.require_bound( + data.unsized_field().unwrap().ty.clone(), + syn::parse_quote!(::domain::utils::UnsizedClone), + ); + + Some(data) + } + + syn::Data::Struct(_) => None, + + syn::Data::Enum(data) => { + for variant in data.variants.iter() { + for field in variant.fields.iter() { + skeleton.require_bound( + field.ty.clone(), + syn::parse_quote!(::domain::__core::clone::Clone), + ); + } + } + + None + } + + syn::Data::Union(data) => { + return Err(Error::new_spanned( + data.union_token, + "'UnsizedClone' cannot be 'derive'd for 'union's", + )); + } + }; + + if let Some(data) = struct_data { + let sized_members = data.sized_members(); + let unsized_member = data.unsized_member().unwrap(); + + skeleton.contents.stmts.push(syn::parse_quote! { + unsafe fn unsized_clone(&self, dst: *mut ()) { + let dst = ::domain::utils::UnsizedClone::ptr_with_address(self, dst); + unsafe { + #(::domain::__core::ptr::write( + ::domain::__core::ptr::addr_of_mut!((*dst).#sized_members), + ::domain::__core::clone::Clone::clone(&self.#sized_members), + );)* + ::domain::utils::UnsizedClone::unsized_clone( + &self.#unsized_member, + ::domain::__core::ptr::addr_of_mut!((*dst).#unsized_member) as *mut (), + ); + } + } + }); + + skeleton.contents.stmts.push(syn::parse_quote! { + fn ptr_with_address(&self, addr: *mut ()) -> *mut Self { + ::domain::utils::UnsizedClone::ptr_with_address( + &self.#unsized_member, + addr, + ) as *mut Self + } + }); + } else { + skeleton.contents.stmts.push(syn::parse_quote! { + unsafe fn unsized_clone(&self, dst: *mut ()) { + let dst = dst as *mut Self; + let this = ::domain::__core::clone::Clone::clone(self); + unsafe { + ::domain::__core::ptr::write(dst as *mut Self, this); + } + } + }); + + skeleton.contents.stmts.push(syn::parse_quote! { + fn ptr_with_address(&self, addr: *mut ()) -> *mut Self { + addr as *mut Self + } + }); + } + + Ok(skeleton.into_token_stream()) + } + + let input = syn::parse_macro_input!(input as syn::DeriveInput); + inner(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + //----------- Utility Functions ---------------------------------------------- /// Add a `field_` prefix to member names. diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 3fd617e90..81e7f5c3a 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -2,6 +2,8 @@ use core::fmt; +use domain_macros::UnsizedClone; + use super::{ build::{self, BuildIntoMessage, BuildResult}, parse::{ParseMessageBytes, SplitMessageBytes}, @@ -11,6 +13,7 @@ use super::{ //----------- CharStr -------------------------------------------------------- /// A DNS "character string". +#[derive(UnsizedClone)] #[repr(transparent)] pub struct CharStr { /// The underlying octets. diff --git a/src/new_base/name/absolute.rs b/src/new_base/name/absolute.rs index c06f41317..0ff9e4d92 100644 --- a/src/new_base/name/absolute.rs +++ b/src/new_base/name/absolute.rs @@ -9,9 +9,14 @@ use core::{ use domain_macros::*; -use crate::new_base::{ - parse::{ParseMessageBytes, SplitMessageBytes}, - wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, +use crate::{ + new_base::{ + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{ + BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError, + }, + }, + utils::CloneFrom, }; use super::LabelIter; @@ -19,7 +24,7 @@ use super::LabelIter; //----------- Name ----------------------------------------------------------- /// An absolute domain name. -#[derive(AsBytes, BuildBytes)] +#[derive(AsBytes, BuildBytes, UnsizedClone)] #[repr(transparent)] pub struct Name([u8]); @@ -241,6 +246,12 @@ impl NameBuf { } } +impl CloneFrom for NameBuf { + fn clone_from(value: &Self::Target) -> Self { + Self::copy_from(value) + } +} + //--- Parsing from DNS messages impl<'a> SplitMessageBytes<'a> for NameBuf { diff --git a/src/new_base/name/label.rs b/src/new_base/name/label.rs index febae38c0..f76fca4c4 100644 --- a/src/new_base/name/label.rs +++ b/src/new_base/name/label.rs @@ -9,12 +9,17 @@ use core::{ ops::{Deref, DerefMut}, }; -use domain_macros::AsBytes; - -use crate::new_base::{ - build::{BuildIntoMessage, BuildResult, Builder}, - parse::{ParseMessageBytes, SplitMessageBytes}, - wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, +use domain_macros::{AsBytes, UnsizedClone}; + +use crate::{ + new_base::{ + build::{BuildIntoMessage, BuildResult, Builder}, + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{ + BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError, + }, + }, + utils::CloneFrom, }; //----------- Label ---------------------------------------------------------- @@ -22,7 +27,7 @@ use crate::new_base::{ /// A label in a domain name. /// /// A label contains up to 63 bytes of arbitrary data. -#[derive(AsBytes)] +#[derive(AsBytes, UnsizedClone)] #[repr(transparent)] pub struct Label([u8]); @@ -300,6 +305,12 @@ impl LabelBuf { } } +impl CloneFrom for LabelBuf { + fn clone_from(value: &Self::Target) -> Self { + Self::copy_from(value) + } +} + //--- Parsing from DNS messages impl ParseMessageBytes<'_> for LabelBuf { diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index e9f7c0318..185a3201b 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -8,10 +8,17 @@ use core::{ ops::{Deref, DerefMut}, }; -use crate::new_base::{ - build::{self, BuildIntoMessage, BuildResult}, - parse::{ParseMessageBytes, SplitMessageBytes}, - wire::{BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError}, +use domain_macros::UnsizedClone; + +use crate::{ + new_base::{ + build::{self, BuildIntoMessage, BuildResult}, + parse::{ParseMessageBytes, SplitMessageBytes}, + wire::{ + BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError, + }, + }, + utils::CloneFrom, }; use super::LabelIter; @@ -25,6 +32,7 @@ use super::LabelIter; /// use, making many common operations (e.g. comparing and ordering domain /// names) more computationally expensive. A [`RevName`] stores the labels in /// reversed order for more efficient use. +#[derive(UnsizedClone)] #[repr(transparent)] pub struct RevName([u8]); @@ -258,6 +266,12 @@ impl RevNameBuf { } } +impl CloneFrom for RevNameBuf { + fn clone_from(value: &Self::Target) -> Self { + Self::copy_from(value) + } +} + //--- Parsing from DNS messages impl<'a> SplitMessageBytes<'a> for RevNameBuf { diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index b9583cf2c..6a83022ff 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -7,6 +7,8 @@ use core::{ ops::{Deref, DerefMut}, }; +use domain_macros::UnsizedClone; + use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, parse::{ParseMessageBytes, SplitMessageBytes}, @@ -25,7 +27,7 @@ use super::{ /// options). When serialized as bytes, the inner value is prefixed with an /// integer (often a [`U16`](super::U16)) indicating the length of the inner /// value in bytes. -#[derive(Copy, Clone, AsBytes)] +#[derive(Copy, Clone, AsBytes, UnsizedClone)] #[repr(C)] pub struct SizePrefixed { /// The size prefix (needed for 'ParseBytesByRef' / 'AsBytes'). diff --git a/src/new_rdata/basic/cname.rs b/src/new_rdata/basic/cname.rs index 24c794e25..4774dcba9 100644 --- a/src/new_rdata/basic/cname.rs +++ b/src/new_rdata/basic/cname.rs @@ -21,6 +21,7 @@ use crate::new_base::{ BuildBytes, ParseBytes, SplitBytes, + UnsizedClone, )] #[repr(transparent)] pub struct CName { diff --git a/src/new_rdata/basic/mx.rs b/src/new_rdata/basic/mx.rs index bf420e517..6d588bb97 100644 --- a/src/new_rdata/basic/mx.rs +++ b/src/new_rdata/basic/mx.rs @@ -21,6 +21,7 @@ use crate::new_base::{ BuildBytes, ParseBytes, SplitBytes, + UnsizedClone, )] #[repr(C)] pub struct Mx { diff --git a/src/new_rdata/basic/ns.rs b/src/new_rdata/basic/ns.rs index e90f1bdb2..e2e2b0d7b 100644 --- a/src/new_rdata/basic/ns.rs +++ b/src/new_rdata/basic/ns.rs @@ -21,6 +21,7 @@ use crate::new_base::{ BuildBytes, ParseBytes, SplitBytes, + UnsizedClone, )] #[repr(transparent)] pub struct Ns { diff --git a/src/new_rdata/basic/ptr.rs b/src/new_rdata/basic/ptr.rs index 7cd530898..7b659d58a 100644 --- a/src/new_rdata/basic/ptr.rs +++ b/src/new_rdata/basic/ptr.rs @@ -21,6 +21,7 @@ use crate::new_base::{ BuildBytes, ParseBytes, SplitBytes, + UnsizedClone, )] #[repr(transparent)] pub struct Ptr { diff --git a/src/new_rdata/basic/txt.rs b/src/new_rdata/basic/txt.rs index ff8830a0d..7003701af 100644 --- a/src/new_rdata/basic/txt.rs +++ b/src/new_rdata/basic/txt.rs @@ -11,7 +11,7 @@ use crate::new_base::{ //----------- Txt ------------------------------------------------------------ /// Free-form text strings about this domain. -#[derive(AsBytes, BuildBytes)] +#[derive(AsBytes, BuildBytes, UnsizedClone)] #[repr(transparent)] pub struct Txt { /// The text strings, as concatenated [`CharStr`]s. diff --git a/src/new_rdata/basic/wks.rs b/src/new_rdata/basic/wks.rs index 2e7418360..4e4d55da6 100644 --- a/src/new_rdata/basic/wks.rs +++ b/src/new_rdata/basic/wks.rs @@ -9,7 +9,7 @@ use super::A; //----------- Wks ------------------------------------------------------------ /// Well-known services supported on this domain. -#[derive(AsBytes, BuildBytes, ParseBytesByRef)] +#[derive(AsBytes, BuildBytes, ParseBytesByRef, UnsizedClone)] #[repr(C, packed)] pub struct Wks { /// The address of the host providing these services. diff --git a/src/new_rdata/dnssec/dnskey.rs b/src/new_rdata/dnssec/dnskey.rs index 0e9c391a9..080e766bf 100644 --- a/src/new_rdata/dnssec/dnskey.rs +++ b/src/new_rdata/dnssec/dnskey.rs @@ -9,7 +9,9 @@ use super::SecAlg; //----------- DNSKey --------------------------------------------------------- /// A cryptographic key for DNS security. -#[derive(Debug, PartialEq, Eq, AsBytes, BuildBytes, ParseBytesByRef)] +#[derive( + Debug, PartialEq, Eq, AsBytes, BuildBytes, ParseBytesByRef, UnsizedClone, +)] #[repr(C)] pub struct DNSKey { /// Flags describing the usage of the key. diff --git a/src/new_rdata/dnssec/mod.rs b/src/new_rdata/dnssec/mod.rs index 6182b615c..b742ecbcd 100644 --- a/src/new_rdata/dnssec/mod.rs +++ b/src/new_rdata/dnssec/mod.rs @@ -38,6 +38,7 @@ pub use ds::{DigestType, Ds}; ParseBytesByRef, SplitBytes, SplitBytesByRef, + UnsizedClone, )] #[repr(transparent)] pub struct SecAlg { diff --git a/src/new_rdata/dnssec/nsec.rs b/src/new_rdata/dnssec/nsec.rs index b7b47191d..ffa9db1c9 100644 --- a/src/new_rdata/dnssec/nsec.rs +++ b/src/new_rdata/dnssec/nsec.rs @@ -36,7 +36,7 @@ impl NSec<'_> { //----------- TypeBitmaps ---------------------------------------------------- /// A bitmap of DNS record types. -#[derive(PartialEq, Eq, AsBytes, BuildBytes)] +#[derive(PartialEq, Eq, AsBytes, BuildBytes, UnsizedClone)] #[repr(transparent)] pub struct TypeBitmaps { octets: [u8], diff --git a/src/new_rdata/dnssec/nsec3.rs b/src/new_rdata/dnssec/nsec3.rs index f8b03b049..46c096357 100644 --- a/src/new_rdata/dnssec/nsec3.rs +++ b/src/new_rdata/dnssec/nsec3.rs @@ -58,6 +58,7 @@ impl NSec3<'_> { BuildBytes, ParseBytesByRef, SplitBytesByRef, + UnsizedClone, )] #[repr(C)] pub struct NSec3Param { diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index d03b739a4..058c776da 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -15,7 +15,15 @@ use crate::{ /// Extended DNS options. #[derive( - PartialEq, Eq, PartialOrd, Ord, Hash, AsBytes, BuildBytes, ParseBytesByRef, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + AsBytes, + BuildBytes, + ParseBytesByRef, + UnsizedClone, )] #[repr(transparent)] pub struct Opt { From 4fd86d447cebc0cc5a33517a0fe64a8a25f92c5b Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 18 Mar 2025 11:33:23 +0100 Subject: [PATCH 152/167] [utils] Define 'clone_to_bump()' for integrating with 'bumpalo' --- Cargo.toml | 4 ++-- src/utils/mod.rs | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 411cb7a42..0a79794e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,12 +55,12 @@ tracing-subscriber = { version = "0.3.18", optional = true, features = ["env-fil default = ["std", "rand"] # Support for libraries -bumpalo = ["dep:bumpalo"] +bumpalo = ["dep:bumpalo", "std"] bytes = ["dep:bytes", "octseq/bytes"] heapless = ["dep:heapless", "octseq/heapless"] serde = ["dep:serde", "octseq/serde"] smallvec = ["dep:smallvec", "octseq/smallvec"] -std = ["dep:hashbrown", "bytes?/std", "octseq/std", "time/std"] +std = ["dep:hashbrown", "bumpalo?/std", "bytes?/std", "octseq/std", "time/std"] tracing = ["dep:log", "dep:tracing"] # Cryptographic backends diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 262ad3e58..01a8259a2 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -251,3 +251,24 @@ impl CloneFrom for Vec { value.into() } } + +//----------- clone_to_bump -------------------------------------------------- + +/// Clone a value into a [`Bump`] allocator. +/// +/// This works with [`UnsizedClone`] values, which extends [`Bump`]'s native +/// functionality. +#[cfg(feature = "bumpalo")] +#[allow(clippy::mut_from_ref)] // using a memory allocator +pub fn clone_to_bump<'a, T: ?Sized + UnsizedClone>( + value: &T, + bump: &'a bumpalo::Bump, +) -> &'a mut T { + let layout = Layout::for_value(value); + let ptr = bump.alloc_layout(layout).as_ptr().cast::<()>(); + unsafe { + value.unsized_clone(ptr); + }; + let ptr = value.ptr_with_address(ptr); + unsafe { &mut *ptr } +} From 00d860253b5ec86616dbbd935b509445d1ca59ac Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Tue, 18 Mar 2025 11:42:41 +0100 Subject: [PATCH 153/167] Replace most 'clone_to_bump()'s via 'UnsizedClone' --- src/new_base/charstr.rs | 13 ------------- src/new_base/name/absolute.rs | 13 ------------- src/new_base/name/reversed.rs | 13 ------------- src/new_base/wire/size_prefixed.rs | 17 ----------------- src/new_rdata/basic/hinfo.rs | 6 ++++-- src/new_rdata/basic/txt.rs | 11 ----------- src/new_rdata/basic/wks.rs | 15 --------------- src/new_rdata/dnssec/dnskey.rs | 13 ------------- src/new_rdata/dnssec/ds.rs | 19 +++---------------- src/new_rdata/dnssec/mod.rs | 1 - src/new_rdata/dnssec/nsec.rs | 21 ++++----------------- src/new_rdata/dnssec/nsec3.rs | 23 +++++------------------ src/new_rdata/dnssec/rrsig.rs | 4 +++- src/new_rdata/edns.rs | 15 --------------- src/new_rdata/mod.rs | 14 ++++++++------ 15 files changed, 27 insertions(+), 171 deletions(-) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 81e7f5c3a..4b02a6097 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -20,19 +20,6 @@ pub struct CharStr { pub octets: [u8], } -//--- Interaction - -impl CharStr { - /// Copy this into the given [`Bump`] allocator. - #[cfg(feature = "bumpalo")] - #[allow(clippy::mut_from_ref)] // using a memory allocator - pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { - let octets = bump.alloc_slice_copy(&self.octets); - // SAFETY: 'CharStr' is 'repr(transparent)' to '[u8]'. - unsafe { core::mem::transmute::<&mut [u8], &mut CharStr>(octets) } - } -} - //--- Parsing from DNS messages impl<'a> SplitMessageBytes<'a> for &'a CharStr { diff --git a/src/new_base/name/absolute.rs b/src/new_base/name/absolute.rs index 0ff9e4d92..975af3cd5 100644 --- a/src/new_base/name/absolute.rs +++ b/src/new_base/name/absolute.rs @@ -100,19 +100,6 @@ impl Name { } } -//--- Interaction - -impl Name { - /// Copy this into the given [`Bump`] allocator. - #[cfg(feature = "bumpalo")] - #[allow(clippy::mut_from_ref)] // using a memory allocator - pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { - let bytes = bump.alloc_slice_copy(self.as_bytes()); - // SAFETY: 'AsBytes' is a transmute, so we can transmute back. - unsafe { core::mem::transmute::<&mut [u8], &mut Self>(bytes) } - } -} - //--- Parsing from bytes impl<'a> ParseBytes<'a> for &'a Name { diff --git a/src/new_base/name/reversed.rs b/src/new_base/name/reversed.rs index 185a3201b..5fcb92735 100644 --- a/src/new_base/name/reversed.rs +++ b/src/new_base/name/reversed.rs @@ -113,19 +113,6 @@ impl RevName { } } -//--- Interaction - -impl RevName { - /// Copy this into the given [`Bump`] allocator. - #[cfg(feature = "bumpalo")] - #[allow(clippy::mut_from_ref)] // using a memory allocator - pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { - let octets = bump.alloc_slice_copy(self.as_bytes()); - // SAFETY: 'RevName' is 'repr(transparent)' to '[u8]'. - unsafe { core::mem::transmute::<&mut [u8], &mut Self>(octets) } - } -} - //--- Building into DNS messages impl BuildIntoMessage for RevName { diff --git a/src/new_base/wire/size_prefixed.rs b/src/new_base/wire/size_prefixed.rs index 6a83022ff..9acac09bd 100644 --- a/src/new_base/wire/size_prefixed.rs +++ b/src/new_base/wire/size_prefixed.rs @@ -63,23 +63,6 @@ where } } -//--- Interaction - -impl SizePrefixed -where - S: AsBytes + SplitBytesByRef + Copy + TryFrom + TryInto, - T: AsBytes + ParseBytesByRef, -{ - /// Copy this into the given [`Bump`] allocator. - #[cfg(feature = "bumpalo")] - #[allow(clippy::mut_from_ref)] // using a memory allocator - pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { - let bytes = bump.alloc_slice_copy(self.as_bytes()); - // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. - unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } - } -} - //--- Conversion from the inner data impl From for SizePrefixed diff --git a/src/new_rdata/basic/hinfo.rs b/src/new_rdata/basic/hinfo.rs index ea6e4ad93..0626855de 100644 --- a/src/new_rdata/basic/hinfo.rs +++ b/src/new_rdata/basic/hinfo.rs @@ -20,9 +20,11 @@ impl HInfo<'_> { /// Copy referenced data into the given [`Bump`] allocator. #[cfg(feature = "bumpalo")] pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> HInfo<'r> { + use crate::utils::clone_to_bump; + HInfo { - cpu: self.cpu.clone_to_bump(bump), - os: self.os.clone_to_bump(bump), + cpu: clone_to_bump(self.cpu, bump), + os: clone_to_bump(self.os, bump), } } } diff --git a/src/new_rdata/basic/txt.rs b/src/new_rdata/basic/txt.rs index 7003701af..611492149 100644 --- a/src/new_rdata/basic/txt.rs +++ b/src/new_rdata/basic/txt.rs @@ -23,17 +23,6 @@ pub struct Txt { //--- Interaction impl Txt { - /// Copy this into the given [`Bump`] allocator. - #[cfg(feature = "bumpalo")] - #[allow(clippy::mut_from_ref)] // using a memory allocator - pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { - use crate::new_base::wire::AsBytes; - - let bytes = bump.alloc_slice_copy(self.as_bytes()); - // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. - unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } - } - /// Iterate over the [`CharStr`]s in this record. pub fn iter(&self) -> impl Iterator + '_ { // NOTE: A TXT record always has at least one 'CharStr' within. diff --git a/src/new_rdata/basic/wks.rs b/src/new_rdata/basic/wks.rs index 4e4d55da6..1facb6f8d 100644 --- a/src/new_rdata/basic/wks.rs +++ b/src/new_rdata/basic/wks.rs @@ -22,21 +22,6 @@ pub struct Wks { pub ports: [u8], } -//--- Interaction - -impl Wks { - /// Copy this into the given [`Bump`] allocator. - #[cfg(feature = "bumpalo")] - #[allow(clippy::mut_from_ref)] // using a memory allocator - pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { - use crate::new_base::wire::ParseBytesByRef; - - let bytes = bump.alloc_slice_copy(self.as_bytes()); - // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. - unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } - } -} - //--- Formatting impl fmt::Debug for Wks { diff --git a/src/new_rdata/dnssec/dnskey.rs b/src/new_rdata/dnssec/dnskey.rs index 080e766bf..52b0abf2f 100644 --- a/src/new_rdata/dnssec/dnskey.rs +++ b/src/new_rdata/dnssec/dnskey.rs @@ -27,19 +27,6 @@ pub struct DNSKey { pub key: [u8], } -impl DNSKey { - /// Copy this into the given [`Bump`] allocator. - #[cfg(feature = "bumpalo")] - #[allow(clippy::mut_from_ref)] // using a memory allocator - pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { - use crate::new_base::wire::{AsBytes, ParseBytesByRef}; - - let bytes = bump.alloc_slice_copy(self.as_bytes()); - // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. - unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } - } -} - //----------- DNSKeyFlags ---------------------------------------------------- /// Flags describing a [`DNSKey`]. diff --git a/src/new_rdata/dnssec/ds.rs b/src/new_rdata/dnssec/ds.rs index ed92dd1f4..f4c265133 100644 --- a/src/new_rdata/dnssec/ds.rs +++ b/src/new_rdata/dnssec/ds.rs @@ -9,7 +9,9 @@ use super::SecAlg; //----------- Ds ------------------------------------------------------------- /// The signing key for a delegated zone. -#[derive(Debug, PartialEq, Eq, AsBytes, BuildBytes, ParseBytesByRef)] +#[derive( + Debug, PartialEq, Eq, AsBytes, BuildBytes, ParseBytesByRef, UnsizedClone, +)] #[repr(C)] pub struct Ds { /// The key tag of the signing key. @@ -25,21 +27,6 @@ pub struct Ds { pub digest: [u8], } -//--- Interaction - -impl Ds { - /// Copy this into the given [`Bump`] allocator. - #[cfg(feature = "bumpalo")] - #[allow(clippy::mut_from_ref)] // using a memory allocator - pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { - use crate::new_base::wire::{AsBytes, ParseBytesByRef}; - - let bytes = bump.alloc_slice_copy(self.as_bytes()); - // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. - unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } - } -} - //----------- DigestType ----------------------------------------------------- /// A cryptographic digest algorithm. diff --git a/src/new_rdata/dnssec/mod.rs b/src/new_rdata/dnssec/mod.rs index b742ecbcd..6182b615c 100644 --- a/src/new_rdata/dnssec/mod.rs +++ b/src/new_rdata/dnssec/mod.rs @@ -38,7 +38,6 @@ pub use ds::{DigestType, Ds}; ParseBytesByRef, SplitBytes, SplitBytesByRef, - UnsizedClone, )] #[repr(transparent)] pub struct SecAlg { diff --git a/src/new_rdata/dnssec/nsec.rs b/src/new_rdata/dnssec/nsec.rs index ffa9db1c9..eaaa97f56 100644 --- a/src/new_rdata/dnssec/nsec.rs +++ b/src/new_rdata/dnssec/nsec.rs @@ -26,9 +26,11 @@ impl NSec<'_> { /// Copy referenced data into the given [`Bump`] allocator. #[cfg(feature = "bumpalo")] pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> NSec<'r> { + use crate::utils::clone_to_bump; + NSec { - next: self.next.clone_to_bump(bump), - types: self.types.clone_to_bump(bump), + next: clone_to_bump(self.next, bump), + types: clone_to_bump(self.types, bump), } } } @@ -69,21 +71,6 @@ impl TypeBitmaps { } } -//--- Interaction - -impl TypeBitmaps { - /// Copy this into the given [`Bump`] allocator. - #[cfg(feature = "bumpalo")] - #[allow(clippy::mut_from_ref)] // using a memory allocator - pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { - use crate::new_base::wire::{AsBytes, ParseBytesByRef}; - - let bytes = bump.alloc_slice_copy(self.as_bytes()); - // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. - unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } - } -} - //--- Formatting impl fmt::Debug for TypeBitmaps { diff --git a/src/new_rdata/dnssec/nsec3.rs b/src/new_rdata/dnssec/nsec3.rs index 46c096357..84f789318 100644 --- a/src/new_rdata/dnssec/nsec3.rs +++ b/src/new_rdata/dnssec/nsec3.rs @@ -36,13 +36,15 @@ impl NSec3<'_> { /// Copy referenced data into the given [`Bump`] allocator. #[cfg(feature = "bumpalo")] pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> NSec3<'r> { + use crate::utils::clone_to_bump; + NSec3 { algorithm: self.algorithm, flags: self.flags, iterations: self.iterations, - salt: self.salt.clone_to_bump(bump), - next: self.next.clone_to_bump(bump), - types: self.types.clone_to_bump(bump), + salt: clone_to_bump(self.salt, bump), + next: clone_to_bump(self.next, bump), + types: clone_to_bump(self.types, bump), } } } @@ -75,21 +77,6 @@ pub struct NSec3Param { pub salt: SizePrefixed, } -//--- Interaction - -impl NSec3Param { - /// Copy this into the given [`Bump`] allocator. - #[cfg(feature = "bumpalo")] - #[allow(clippy::mut_from_ref)] // using a memory allocator - pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { - use crate::new_base::wire::{AsBytes, ParseBytesByRef}; - - let bytes = bump.alloc_slice_copy(self.as_bytes()); - // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. - unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } - } -} - //----------- NSec3HashAlg --------------------------------------------------- /// The hash algorithm used with [`NSec3`] records. diff --git a/src/new_rdata/dnssec/rrsig.rs b/src/new_rdata/dnssec/rrsig.rs index 60f9fddb3..8fd348ccd 100644 --- a/src/new_rdata/dnssec/rrsig.rs +++ b/src/new_rdata/dnssec/rrsig.rs @@ -43,8 +43,10 @@ impl RRSig<'_> { /// Copy referenced data into the given [`Bump`] allocator. #[cfg(feature = "bumpalo")] pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> RRSig<'r> { + use crate::utils::clone_to_bump; + RRSig { - signer: self.signer.clone_to_bump(bump), + signer: clone_to_bump(self.signer, bump), signature: bump.alloc_slice_copy(self.signature), ..self.clone() } diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 058c776da..9389d6984 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -40,21 +40,6 @@ impl Opt { } } -//--- Interaction - -impl Opt { - /// Copy this into the given [`Bump`] allocator. - #[cfg(feature = "bumpalo")] - #[allow(clippy::mut_from_ref)] // using a memory allocator - pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { - use crate::new_base::wire::{AsBytes, ParseBytesByRef}; - - let bytes = bump.alloc_slice_copy(self.as_bytes()); - // SAFETY: 'ParseBytesByRef' and 'AsBytes' are inverses. - unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } - } -} - //--- Formatting impl fmt::Debug for Opt { diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 540cbed6e..da6f1fe96 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -181,25 +181,27 @@ impl<'a, N> RecordData<'a, N> { where N: Clone, { + use crate::utils::clone_to_bump; + match self { Self::A(&r) => RecordData::A(bump.alloc(r)), Self::Ns(r) => RecordData::Ns(r.clone()), Self::CName(r) => RecordData::CName(r.clone()), Self::Soa(r) => RecordData::Soa(r.clone()), - Self::Wks(r) => RecordData::Wks(r.clone_to_bump(bump)), + Self::Wks(r) => RecordData::Wks(clone_to_bump(*r, bump)), Self::Ptr(r) => RecordData::Ptr(r.clone()), Self::HInfo(r) => RecordData::HInfo(r.clone_to_bump(bump)), Self::Mx(r) => RecordData::Mx(r.clone()), - Self::Txt(r) => RecordData::Txt(r.clone_to_bump(bump)), + Self::Txt(r) => RecordData::Txt(clone_to_bump(*r, bump)), Self::Aaaa(&r) => RecordData::Aaaa(bump.alloc(r)), - Self::Opt(r) => RecordData::Opt(r.clone_to_bump(bump)), - Self::Ds(r) => RecordData::Ds(r.clone_to_bump(bump)), + Self::Opt(r) => RecordData::Opt(clone_to_bump(*r, bump)), + Self::Ds(r) => RecordData::Ds(clone_to_bump(*r, bump)), Self::RRSig(r) => RecordData::RRSig(r.clone_to_bump(bump)), Self::NSec(r) => RecordData::NSec(r.clone_to_bump(bump)), - Self::DNSKey(r) => RecordData::DNSKey(r.clone_to_bump(bump)), + Self::DNSKey(r) => RecordData::DNSKey(clone_to_bump(*r, bump)), Self::NSec3(r) => RecordData::NSec3(r.clone_to_bump(bump)), Self::NSec3Param(r) => { - RecordData::NSec3Param(r.clone_to_bump(bump)) + RecordData::NSec3Param(clone_to_bump(*r, bump)) } Self::Unknown(rt, rd) => { RecordData::Unknown(*rt, rd.clone_to_bump(bump)) From 940dd47eb7702b2627e2e0da6d347885c31f4ba4 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 19 Mar 2025 10:59:53 +0100 Subject: [PATCH 154/167] Fix broken doc links --- src/new_rdata/basic/hinfo.rs | 2 +- src/new_rdata/dnssec/nsec.rs | 2 +- src/new_rdata/dnssec/nsec3.rs | 2 +- src/new_rdata/dnssec/rrsig.rs | 2 +- src/new_rdata/mod.rs | 4 ++-- src/utils/mod.rs | 4 +++- 6 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/new_rdata/basic/hinfo.rs b/src/new_rdata/basic/hinfo.rs index 0626855de..3f5674a1f 100644 --- a/src/new_rdata/basic/hinfo.rs +++ b/src/new_rdata/basic/hinfo.rs @@ -17,7 +17,7 @@ pub struct HInfo<'a> { //--- Interaction impl HInfo<'_> { - /// Copy referenced data into the given [`Bump`] allocator. + /// Copy referenced data into the given [`Bump`](bumpalo::Bump) allocator. #[cfg(feature = "bumpalo")] pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> HInfo<'r> { use crate::utils::clone_to_bump; diff --git a/src/new_rdata/dnssec/nsec.rs b/src/new_rdata/dnssec/nsec.rs index eaaa97f56..24b8a45e7 100644 --- a/src/new_rdata/dnssec/nsec.rs +++ b/src/new_rdata/dnssec/nsec.rs @@ -23,7 +23,7 @@ pub struct NSec<'a> { //--- Interaction impl NSec<'_> { - /// Copy referenced data into the given [`Bump`] allocator. + /// Copy referenced data into the given [`Bump`](bumpalo::Bump) allocator. #[cfg(feature = "bumpalo")] pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> NSec<'r> { use crate::utils::clone_to_bump; diff --git a/src/new_rdata/dnssec/nsec3.rs b/src/new_rdata/dnssec/nsec3.rs index 84f789318..56877f4ef 100644 --- a/src/new_rdata/dnssec/nsec3.rs +++ b/src/new_rdata/dnssec/nsec3.rs @@ -33,7 +33,7 @@ pub struct NSec3<'a> { //--- Interaction impl NSec3<'_> { - /// Copy referenced data into the given [`Bump`] allocator. + /// Copy referenced data into the given [`Bump`](bumpalo::Bump) allocator. #[cfg(feature = "bumpalo")] pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> NSec3<'r> { use crate::utils::clone_to_bump; diff --git a/src/new_rdata/dnssec/rrsig.rs b/src/new_rdata/dnssec/rrsig.rs index 8fd348ccd..a717caa9e 100644 --- a/src/new_rdata/dnssec/rrsig.rs +++ b/src/new_rdata/dnssec/rrsig.rs @@ -40,7 +40,7 @@ pub struct RRSig<'a> { //--- Interaction impl RRSig<'_> { - /// Copy referenced data into the given [`Bump`] allocator. + /// Copy referenced data into the given [`Bump`](bumpalo::Bump) allocator. #[cfg(feature = "bumpalo")] pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> RRSig<'r> { use crate::utils::clone_to_bump; diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index da6f1fe96..82c408259 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -172,7 +172,7 @@ impl<'a, N> RecordData<'a, N> { } } - /// Copy referenced data into the given [`Bump`] allocator. + /// Copy referenced data into the given [`Bump`](bumpalo::Bump) allocator. #[cfg(feature = "bumpalo")] pub fn clone_to_bump<'r>( &self, @@ -378,7 +378,7 @@ pub struct UnknownRecordData { //--- Interaction impl UnknownRecordData { - /// Copy this into the given [`Bump`] allocator. + /// Copy referenced data into the given [`Bump`](bumpalo::Bump) allocator. #[cfg(feature = "bumpalo")] #[allow(clippy::mut_from_ref)] // using a memory allocator pub fn clone_to_bump<'r>(&self, bump: &'r bumpalo::Bump) -> &'r mut Self { diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 01a8259a2..00e4d12c0 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -31,7 +31,7 @@ pub unsafe trait UnsizedClone { /// Change the address of a pointer to [`Self`]. /// /// When [`Self`] is used as the last field in a type that also implements - /// [`ParseBytesByRef`], it may be dynamically sized, and so a pointer (or + /// [`UnsizedClone`], it may be dynamically sized, and so a pointer (or /// reference) to it may include additional metadata. This metadata is /// included verbatim in any reference/pointer to the containing type. /// @@ -258,6 +258,8 @@ impl CloneFrom for Vec { /// /// This works with [`UnsizedClone`] values, which extends [`Bump`]'s native /// functionality. +/// +/// [`Bump`]: bumpalo::Bump #[cfg(feature = "bumpalo")] #[allow(clippy::mut_from_ref)] // using a memory allocator pub fn clone_to_bump<'a, T: ?Sized + UnsizedClone>( From b405e95a1e94e961e7b6cdd1a62d30c9d242f1a9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 19 Mar 2025 11:03:08 +0100 Subject: [PATCH 155/167] [new_base/name/absolute] Impl 'Ord' for 'Name' --- src/new_base/name/absolute.rs | 94 ++++++++++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 1 deletion(-) diff --git a/src/new_base/name/absolute.rs b/src/new_base/name/absolute.rs index 975af3cd5..70d932e1b 100644 --- a/src/new_base/name/absolute.rs +++ b/src/new_base/name/absolute.rs @@ -2,6 +2,7 @@ use core::{ borrow::{Borrow, BorrowMut}, + cmp::Ordering, fmt, hash::{Hash, Hasher}, ops::{Deref, DerefMut}, @@ -163,6 +164,85 @@ impl PartialEq for Name { impl Eq for Name {} +//--- Comparison + +impl PartialOrd for Name { + fn partial_cmp(&self, that: &Self) -> Option { + Some(self.cmp(that)) + } +} + +impl Ord for Name { + fn cmp(&self, that: &Self) -> Ordering { + // We wish to compare the labels in these names in reverse order. + // Unfortunately, labels in absolute names cannot be traversed + // backwards efficiently. We need to try harder. + // + // Consider two names that are not equal. This means that one name is + // a strict suffix of the other, or that the two had different labels + // at some position. Following this mismatched label is a suffix of + // labels that both names do agree on. + // + // We traverse the bytes in the names in reverse order and find the + // length of their shared suffix. The actual shared suffix, in units + // of labels, may be shorter than this (because the last bytes of the + // mismatched labels could be the same). + // + // Then, we traverse the labels of both names in forward order, until + // we hit the shared suffix territory. We try to match up the names + // in order to discover the real shared suffix. Once the suffix is + // found, the immediately preceding label (if there is one) contains + // the inequality, and can be compared as usual. + + let suffix_len = core::iter::zip( + self.as_bytes().iter().rev().map(u8::to_ascii_lowercase), + that.as_bytes().iter().rev().map(u8::to_ascii_lowercase), + ) + .position(|(a, b)| a != b); + + let Some(suffix_len) = suffix_len else { + // 'iter::zip()' simply ignores unequal iterators, stopping when + // either iterator finishes. Even though the two names had no + // mismatching bytes, one could be longer than the other. + return self.len().cmp(&that.len()); + }; + + // Prepare for forward traversal. + let (mut lhs, mut rhs) = (self.labels(), that.labels()); + // SAFETY: There is at least one unequal byte, and it cannot be the + // root label, so both names have at least one additional label. + let mut prev = unsafe { + (lhs.next().unwrap_unchecked(), rhs.next().unwrap_unchecked()) + }; + + // Traverse both names in lockstep, trying to match their lengths. + loop { + let (llen, rlen) = (lhs.remaining().len(), rhs.remaining().len()); + if llen == rlen && llen <= suffix_len { + // We're in shared suffix territory, and 'lhs' and 'rhs' have + // the same length. Thus, they must be identical, and we have + // found the shared suffix. + break prev.0.cmp(prev.1); + } else if llen > rlen { + // Try to match the lengths by shortening 'lhs'. + + // SAFETY: 'llen > rlen >= 1', thus 'lhs' contains at least + // one additional label before the root. + prev.0 = unsafe { lhs.next().unwrap_unchecked() }; + } else { + // Try to match the lengths by shortening 'rhs'. + + // SAFETY: Either: + // - '1 <= llen < rlen', thus 'rhs' contains at least one + // additional label before the root. + // - 'llen == rlen > suffix_len >= 1', thus 'rhs' contains at + // least one additional label before the root. + prev.1 = unsafe { rhs.next().unwrap_unchecked() }; + } + } + } +} + //--- Hashing impl Hash for Name { @@ -452,7 +532,7 @@ impl AsMut for NameBuf { } } -//--- Forwarding equality, hashing, and formatting +//--- Forwarding equality, comparison, hashing, and formatting impl PartialEq for NameBuf { fn eq(&self, that: &Self) -> bool { @@ -462,6 +542,18 @@ impl PartialEq for NameBuf { impl Eq for NameBuf {} +impl PartialOrd for NameBuf { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for NameBuf { + fn cmp(&self, other: &Self) -> Ordering { + (**self).cmp(&**other) + } +} + impl Hash for NameBuf { fn hash(&self, state: &mut H) { (**self).hash(state) From c9040101f0097c4e87251b4e0589f7f3a6151095 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 19 Mar 2025 14:25:58 +0100 Subject: [PATCH 156/167] [new_base/message] Impl 'UnsizedClone' for 'Message' --- src/new_base/message.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/new_base/message.rs b/src/new_base/message.rs index 27e4cde88..7df0c47e1 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -9,7 +9,7 @@ use super::wire::{AsBytes, ParseBytesByRef, U16}; //----------- Message -------------------------------------------------------- /// A DNS message. -#[derive(AsBytes, BuildBytes, ParseBytesByRef)] +#[derive(AsBytes, BuildBytes, ParseBytesByRef, UnsizedClone)] #[repr(C, packed)] pub struct Message { /// The message header. From 0366c2065debb86e59cc301c22ba0d6a57d40d84 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 19 Mar 2025 15:13:16 +0100 Subject: [PATCH 157/167] [new_base/name] Introduce 'CanonicalName' 'CanonicalName' introduces important operations for DNSSEC. It is only implemented by 'Name' (not 'RevName') because 'RevName' cannot implement it efficiently (users should prefer 'RecordData<&Name>'). --- src/new_base/name/absolute.rs | 17 ++++- src/new_base/name/mod.rs | 140 ++++++++++++++++++++++++++++++++++ src/new_base/wire/build.rs | 9 +++ 3 files changed, 165 insertions(+), 1 deletion(-) diff --git a/src/new_base/name/absolute.rs b/src/new_base/name/absolute.rs index 70d932e1b..0b14263c5 100644 --- a/src/new_base/name/absolute.rs +++ b/src/new_base/name/absolute.rs @@ -20,7 +20,7 @@ use crate::{ utils::CloneFrom, }; -use super::LabelIter; +use super::{CanonicalName, LabelIter}; //----------- Name ----------------------------------------------------------- @@ -101,6 +101,21 @@ impl Name { } } +//--- Canonical operations + +impl CanonicalName for Name { + fn cmp_composed(&self, other: &Self) -> Ordering { + self.as_bytes().cmp(other.as_bytes()) + } + + fn cmp_lowercase_composed(&self, other: &Self) -> Ordering { + self.as_bytes() + .iter() + .map(u8::to_ascii_lowercase) + .cmp(other.as_bytes().iter().map(u8::to_ascii_lowercase)) + } +} + //--- Parsing from bytes impl<'a> ParseBytes<'a> for &'a Name { diff --git a/src/new_base/name/mod.rs b/src/new_base/name/mod.rs index c1cc2df2d..63cf7f71e 100644 --- a/src/new_base/name/mod.rs +++ b/src/new_base/name/mod.rs @@ -14,6 +14,12 @@ //! with the `.example.org` suffix. The reverse order in which this hierarchy //! is expressed can sometimes be confusing. +use core::cmp::Ordering; + +use super::wire::{BuildBytes, TruncationError}; + +//--- Submodules + mod label; pub use label::{Label, LabelBuf, LabelIter}; @@ -25,3 +31,137 @@ pub use reversed::{RevName, RevNameBuf}; mod unparsed; pub use unparsed::UnparsedName; + +//----------- CanonicalName -------------------------------------------------- + +/// DNSSEC-conformant operations for domain names. +/// +/// As specified by [RFC 4034, section 6], domain names are used in two +/// different ways: they can be serialized into byte strings or compared. +/// +/// - In record data, they are serialized following the regular wire format +/// (specifically without name compression). However, in some record data +/// types, labels are converted to lowercase for serialization. +/// +/// - In record owner names, they are compared from the root label outwards, +/// with the contents of each label being compared case-insensitively. +/// +/// - In record data, they are compared as serialized byte strings. As +/// explained above, there are two different valid serializations (i.e. the +/// labels may be lowercased, or the original case may be retained). +/// +/// [RFC 4034, section 6]: https://datatracker.ietf.org/doc/html/rfc4034#section-6 +/// +/// If a domain name type implements [`CanonicalName`], then [`BuildBytes`] +/// will serialize the name in the wire format (without changing the case of +/// its labels). [`Ord`] will compare domain names as if they were the owner +/// names of records (i.e. not as if they were serialized byte strings). +pub trait CanonicalName: BuildBytes + Ord { + /// Serialize a domain name with lowercased labels. + /// + /// This is subtly different from [`BuildBytes`]; it requires all the + /// characters in the domain name to be lowercased. It is implemented + /// automatically, but it could be overriden for performance. + fn build_lowercased_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + // Build the bytes as usual. + let rest = self.build_bytes(bytes)?.len(); + + // Find the built bytes and lowercase them. + let (bytes, rest) = bytes.split_at_mut(rest); + bytes.make_ascii_lowercase(); + + Ok(rest) + } + + /// Compare domain names as if they were in the wire format. + /// + /// This is equivalent to serializing both domain names in the wire format + /// using [`BuildBytes`] and comparing the resulting byte strings. It is + /// implemented automatically, but it could be overriden for performance. + fn cmp_composed(&self, other: &Self) -> Ordering { + // Build both names into byte arrays. + + let mut this = [0u8; 255]; + let rest_len = self + .build_bytes(&mut this) + .expect("domain names are at most 255 bytes when serialized") + .len(); + let this = &this[..this.len() - rest_len]; + + let mut that = [0u8; 255]; + let rest_len = other + .build_bytes(&mut that) + .expect("domain names are at most 255 bytes when serialized") + .len(); + let that = &that[..that.len() - rest_len]; + + // Compare the byte strings. + this.cmp(that) + } + + /// Compare domain names as if they were in the wire format, lowercased. + /// + /// This is equivalent to serializing both domain names in the wire format + /// using [`build_lowercased_bytes()`] and comparing the resulting byte + /// strings. It is implemented automatically, but it could be overriden + /// for performance. + /// + /// [`build_lowercased_bytes()`]: Self::build_lowercased_bytes() + fn cmp_lowercase_composed(&self, other: &Self) -> Ordering { + // Build both names into byte arrays. + + let mut this = [0u8; 255]; + let rest_len = self + .build_lowercased_bytes(&mut this) + .expect("domain names are at most 255 bytes when serialized") + .len(); + let this = &this[..this.len() - rest_len]; + + let mut that = [0u8; 255]; + let rest_len = other + .build_lowercased_bytes(&mut that) + .expect("domain names are at most 255 bytes when serialized") + .len(); + let that = &that[..that.len() - rest_len]; + + // Compare the byte strings. + this.cmp(that) + } +} + +impl CanonicalName for &N { + fn build_lowercased_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + (**self).build_lowercased_bytes(bytes) + } + + fn cmp_composed(&self, other: &Self) -> Ordering { + (**self).cmp_composed(*other) + } + + fn cmp_lowercase_composed(&self, other: &Self) -> Ordering { + (**self).cmp_lowercase_composed(*other) + } +} + +impl CanonicalName for &mut N { + fn build_lowercased_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + (**self).build_lowercased_bytes(bytes) + } + + fn cmp_composed(&self, other: &Self) -> Ordering { + (**self).cmp_composed(*other) + } + + fn cmp_lowercase_composed(&self, other: &Self) -> Ordering { + (**self).cmp_lowercase_composed(*other) + } +} diff --git a/src/new_base/wire/build.rs b/src/new_base/wire/build.rs index 88b6a44b3..e45c047d2 100644 --- a/src/new_base/wire/build.rs +++ b/src/new_base/wire/build.rs @@ -29,6 +29,15 @@ impl BuildBytes for &T { } } +impl BuildBytes for &mut T { + fn build_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + T::build_bytes(*self, bytes) + } +} + impl BuildBytes for u8 { fn build_bytes<'b>( &self, From 2623bb85eef3c6d9049e72e9a49d47d86b6215f9 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Wed, 19 Mar 2025 16:26:00 +0100 Subject: [PATCH 158/167] Implement DNSSEC canonical ordering and building --- src/new_base/charstr.rs | 19 ++++++++ src/new_base/mod.rs | 4 +- src/new_base/record.rs | 80 +++++++++++++++++++++++++++++++++- src/new_rdata/basic/a.rs | 10 +++++ src/new_rdata/basic/cname.rs | 21 ++++++++- src/new_rdata/basic/hinfo.rs | 24 +++++++++- src/new_rdata/basic/mx.rs | 25 ++++++++++- src/new_rdata/basic/ns.rs | 21 ++++++++- src/new_rdata/basic/ptr.rs | 21 ++++++++- src/new_rdata/basic/soa.rs | 36 ++++++++++++++- src/new_rdata/basic/txt.rs | 12 ++++- src/new_rdata/basic/wks.rs | 12 ++++- src/new_rdata/dnssec/dnskey.rs | 15 ++++++- src/new_rdata/dnssec/ds.rs | 15 ++++++- src/new_rdata/dnssec/nsec.rs | 18 ++++++-- src/new_rdata/dnssec/nsec3.rs | 41 ++++++++++++++++- src/new_rdata/dnssec/rrsig.rs | 37 +++++++++++++++- src/new_rdata/edns.rs | 15 ++++++- src/new_rdata/ipv6.rs | 10 +++++ src/new_rdata/mod.rs | 75 ++++++++++++++++++++++++++++++- 20 files changed, 483 insertions(+), 28 deletions(-) diff --git a/src/new_base/charstr.rs b/src/new_base/charstr.rs index 4b02a6097..6e9c97bc6 100644 --- a/src/new_base/charstr.rs +++ b/src/new_base/charstr.rs @@ -17,9 +17,28 @@ use super::{ #[repr(transparent)] pub struct CharStr { /// The underlying octets. + /// + /// This is at most 255 bytes. It does not include the length octet that + /// precedes the character string when serialized in the wire format. pub octets: [u8], } +//--- Inspection + +impl CharStr { + /// The length of the [`CharStr`]. + /// + /// This is always less than 256 -- it is guaranteed to fit in a [`u8`]. + pub const fn len(&self) -> usize { + self.octets.len() + } + + /// Whether the [`CharStr`] is empty. + pub const fn is_empty(&self) -> bool { + self.octets.is_empty() + } +} + //--- Parsing from DNS messages impl<'a> SplitMessageBytes<'a> for &'a CharStr { diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index df8632884..b3903d90c 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -14,8 +14,8 @@ pub use question::{QClass, QType, Question, UnparsedQuestion}; mod record; pub use record::{ - ParseRecordData, RClass, RType, Record, UnparsedRecord, - UnparsedRecordData, TTL, + CanonicalRecordData, ParseRecordData, RClass, RType, Record, + UnparsedRecord, UnparsedRecordData, TTL, }; //--- Elements of DNS messages diff --git a/src/new_base/record.rs b/src/new_base/record.rs index 9a2813b9a..e435e9dd1 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -1,6 +1,6 @@ //! DNS records. -use core::{borrow::Borrow, fmt, ops::Deref}; +use core::{borrow::Borrow, cmp::Ordering, fmt, ops::Deref}; use super::{ build::{self, BuildIntoMessage, BuildResult}, @@ -258,6 +258,55 @@ impl RType { pub const NSEC3PARAM: Self = Self::new(51); } +//--- Interaction + +impl RType { + /// Whether this type uses lowercased domain names in canonical form. + /// + /// As specified by [RFC 4034, section 6.2] (and updated by [RFC 6840, + /// section 5.1]), the canonical form of the record data of any of the + /// following types will have its domain names lowercased: + /// + /// - [`NS`](RType::NS) + /// - `MD` (obsolete) + /// - `MF` (obsolete) + /// - [`CNAME`](RType::CNAME) + /// - [`SOA`](RType::SOA) + /// - `MB` + /// - `MG` + /// - `MR` + /// - [`PTR`](RType::PTR) + /// - `MINFO` + /// - [`MX`](RType::MX) + /// - `RP` + /// - `AFSDB` + /// - `RT` + /// - `SIG` (obsolete) + /// - `PX` + /// - `NXT` (obsolete) + /// - `NAPTR` + /// - `KX` + /// - `SRV` + /// - `DNAME` + /// - `A6` (obsolete) + /// - [`RRSIG`](RType::RRSIG) + /// + /// [RFC 4034, section 6.2]: https://datatracker.ietf.org/doc/html/rfc4034#section-6.2 + /// [RFC 6840, section 5.1]: https://datatracker.ietf.org/doc/html/rfc6840#section-5.1 + pub const fn uses_lowercase_canonical_form(&self) -> bool { + // TODO: Update this as more types are added. + matches!( + *self, + Self::NS + | Self::CNAME + | Self::SOA + | Self::PTR + | Self::MX + | Self::RRSIG + ) + } +} + //--- Conversion to and from 'u16' impl From for RType { @@ -421,6 +470,35 @@ pub trait ParseRecordData<'a>: Sized { ) -> Result; } +//----------- CanonicalRecordData -------------------------------------------- + +/// DNSSEC-conformant operations for resource records. +/// +/// As specified by [RFC 4034, section 6], there is a "canonical form" for +/// DNS resource records, used for ordering records and computing signatures. +/// This trait defines operations for working with the canonical form. +/// +/// [RFC 4034, section 6]: https://datatracker.ietf.org/doc/html/rfc4034#section-6 +pub trait CanonicalRecordData: BuildBytes { + /// Serialize record data in the canonical form. + /// + /// This is subtly different from [`BuildBytes`]: for certain special + /// record data types, it causes embedded domain names to be lowercased. + /// By default, it will fall back to [`BuildBytes`]. + fn build_canonical_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.build_bytes(bytes) + } + + /// Compare record data in the canonical form. + /// + /// This is equivalent to serializing both record data instances using + /// [`build_canonical_bytes()`] and comparing the resulting byte strings. + fn cmp_canonical(&self, other: &Self) -> Ordering; +} + //----------- UnparsedRecordData --------------------------------------------- /// Unparsed DNS record data. diff --git a/src/new_rdata/basic/a.rs b/src/new_rdata/basic/a.rs index 14700fdd5..4d5b3af13 100644 --- a/src/new_rdata/basic/a.rs +++ b/src/new_rdata/basic/a.rs @@ -1,3 +1,4 @@ +use core::cmp::Ordering; use core::fmt; use core::net::Ipv4Addr; use core::str::FromStr; @@ -5,6 +6,7 @@ use core::str::FromStr; use domain_macros::*; use crate::new_base::wire::AsBytes; +use crate::new_base::CanonicalRecordData; //----------- A -------------------------------------------------------------- @@ -47,6 +49,14 @@ impl From for Ipv4Addr { } } +//--- Canonical operations + +impl CanonicalRecordData for A { + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.octets.cmp(&other.octets) + } +} + //--- Parsing from a string impl FromStr for A { diff --git a/src/new_rdata/basic/cname.rs b/src/new_rdata/basic/cname.rs index 4774dcba9..035546b7f 100644 --- a/src/new_rdata/basic/cname.rs +++ b/src/new_rdata/basic/cname.rs @@ -1,9 +1,13 @@ +use core::cmp::Ordering; + use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, + name::CanonicalName, parse::ParseMessageBytes, - wire::ParseError, + wire::{ParseError, TruncationError}, + CanonicalRecordData, }; //----------- CName ---------------------------------------------------------- @@ -50,6 +54,21 @@ impl CName { } } +//--- Canonical operations + +impl CanonicalRecordData for CName { + fn build_canonical_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.name.build_lowercased_bytes(bytes) + } + + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.name.cmp_lowercase_composed(&other.name) + } +} + //--- Parsing from DNS messages impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for CName { diff --git a/src/new_rdata/basic/hinfo.rs b/src/new_rdata/basic/hinfo.rs index 3f5674a1f..389a605a7 100644 --- a/src/new_rdata/basic/hinfo.rs +++ b/src/new_rdata/basic/hinfo.rs @@ -1,6 +1,8 @@ +use core::cmp::Ordering; + use domain_macros::*; -use crate::new_base::CharStr; +use crate::new_base::{CanonicalRecordData, CharStr}; //----------- HInfo ---------------------------------------------------------- @@ -28,3 +30,23 @@ impl HInfo<'_> { } } } + +//--- Canonical operations + +impl CanonicalRecordData for HInfo<'_> { + fn cmp_canonical(&self, that: &Self) -> Ordering { + let this = ( + self.cpu.len(), + &self.cpu.octets, + self.os.len(), + &self.os.octets, + ); + let that = ( + that.cpu.len(), + &that.cpu.octets, + that.os.len(), + &that.os.octets, + ); + this.cmp(&that) + } +} diff --git a/src/new_rdata/basic/mx.rs b/src/new_rdata/basic/mx.rs index 6d588bb97..ffd7adfe9 100644 --- a/src/new_rdata/basic/mx.rs +++ b/src/new_rdata/basic/mx.rs @@ -1,9 +1,13 @@ +use core::cmp::Ordering; + use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, + name::CanonicalName, parse::{ParseMessageBytes, SplitMessageBytes}, - wire::{AsBytes, ParseError, U16}, + wire::{AsBytes, BuildBytes, ParseError, TruncationError, U16}, + CanonicalRecordData, }; //----------- Mx ------------------------------------------------------------- @@ -55,6 +59,25 @@ impl Mx { } } +//--- Canonical operations + +impl CanonicalRecordData for Mx { + fn build_canonical_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + let bytes = self.preference.build_bytes(bytes)?; + let bytes = self.exchange.build_lowercased_bytes(bytes)?; + Ok(bytes) + } + + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.preference.cmp(&other.preference).then_with(|| { + self.exchange.cmp_lowercase_composed(&other.exchange) + }) + } +} + //--- Parsing from DNS messages impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Mx { diff --git a/src/new_rdata/basic/ns.rs b/src/new_rdata/basic/ns.rs index e2e2b0d7b..bb2b06788 100644 --- a/src/new_rdata/basic/ns.rs +++ b/src/new_rdata/basic/ns.rs @@ -1,9 +1,13 @@ +use core::cmp::Ordering; + use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, + name::CanonicalName, parse::ParseMessageBytes, - wire::ParseError, + wire::{ParseError, TruncationError}, + CanonicalRecordData, }; //----------- Ns ------------------------------------------------------------- @@ -50,6 +54,21 @@ impl Ns { } } +//--- Canonical operations + +impl CanonicalRecordData for Ns { + fn build_canonical_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.name.build_lowercased_bytes(bytes) + } + + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.name.cmp_lowercase_composed(&other.name) + } +} + //--- Parsing from DNS messages impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ns { diff --git a/src/new_rdata/basic/ptr.rs b/src/new_rdata/basic/ptr.rs index 7b659d58a..1102274ff 100644 --- a/src/new_rdata/basic/ptr.rs +++ b/src/new_rdata/basic/ptr.rs @@ -1,9 +1,13 @@ +use core::cmp::Ordering; + use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, + name::CanonicalName, parse::ParseMessageBytes, - wire::ParseError, + wire::{ParseError, TruncationError}, + CanonicalRecordData, }; //----------- Ptr ------------------------------------------------------------ @@ -50,6 +54,21 @@ impl Ptr { } } +//--- Canonical operations + +impl CanonicalRecordData for Ptr { + fn build_canonical_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + self.name.build_lowercased_bytes(bytes) + } + + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.name.cmp_lowercase_composed(&other.name) + } +} + //--- Parsing from DNS messages impl<'a, N: ParseMessageBytes<'a>> ParseMessageBytes<'a> for Ptr { diff --git a/src/new_rdata/basic/soa.rs b/src/new_rdata/basic/soa.rs index 604648c95..3006f430d 100644 --- a/src/new_rdata/basic/soa.rs +++ b/src/new_rdata/basic/soa.rs @@ -1,10 +1,13 @@ +use core::cmp::Ordering; + use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, + name::CanonicalName, parse::{ParseMessageBytes, SplitMessageBytes}, - wire::{AsBytes, ParseError, U32}, - Serial, + wire::{AsBytes, BuildBytes, ParseError, TruncationError, U32}, + CanonicalRecordData, Serial, }; //----------- Soa ------------------------------------------------------------ @@ -77,6 +80,35 @@ impl Soa { } } +//--- Canonical operations + +impl CanonicalRecordData for Soa { + fn build_canonical_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + let bytes = self.mname.build_lowercased_bytes(bytes)?; + let bytes = self.rname.build_lowercased_bytes(bytes)?; + let bytes = self.serial.build_bytes(bytes)?; + let bytes = self.refresh.build_bytes(bytes)?; + let bytes = self.retry.build_bytes(bytes)?; + let bytes = self.expire.build_bytes(bytes)?; + let bytes = self.minimum.build_bytes(bytes)?; + Ok(bytes) + } + + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.mname + .cmp_lowercase_composed(&other.mname) + .then_with(|| self.rname.cmp_lowercase_composed(&other.rname)) + .then_with(|| self.serial.as_bytes().cmp(other.serial.as_bytes())) + .then_with(|| self.refresh.cmp(&other.refresh)) + .then_with(|| self.retry.cmp(&other.retry)) + .then_with(|| self.expire.cmp(&other.expire)) + .then_with(|| self.minimum.cmp(&other.minimum)) + } +} + //--- Parsing from DNS messages impl<'a, N: SplitMessageBytes<'a>> ParseMessageBytes<'a> for Soa { diff --git a/src/new_rdata/basic/txt.rs b/src/new_rdata/basic/txt.rs index 611492149..77df69d7e 100644 --- a/src/new_rdata/basic/txt.rs +++ b/src/new_rdata/basic/txt.rs @@ -1,11 +1,11 @@ -use core::fmt; +use core::{cmp::Ordering, fmt}; use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, wire::{ParseBytesByRef, ParseError, SplitBytes}, - CharStr, + CanonicalRecordData, CharStr, }; //----------- Txt ------------------------------------------------------------ @@ -38,6 +38,14 @@ impl Txt { } } +//--- Canonical operations + +impl CanonicalRecordData for Txt { + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.content.cmp(&other.content) + } +} + //--- Building into DNS messages impl BuildIntoMessage for Txt { diff --git a/src/new_rdata/basic/wks.rs b/src/new_rdata/basic/wks.rs index 1facb6f8d..4e337433f 100644 --- a/src/new_rdata/basic/wks.rs +++ b/src/new_rdata/basic/wks.rs @@ -1,8 +1,8 @@ -use core::fmt; +use core::{cmp::Ordering, fmt}; use domain_macros::*; -use crate::new_base::wire::AsBytes; +use crate::new_base::{wire::AsBytes, CanonicalRecordData}; use super::A; @@ -22,6 +22,14 @@ pub struct Wks { pub ports: [u8], } +//--- Canonical operations + +impl CanonicalRecordData for Wks { + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.as_bytes().cmp(other.as_bytes()) + } +} + //--- Formatting impl fmt::Debug for Wks { diff --git a/src/new_rdata/dnssec/dnskey.rs b/src/new_rdata/dnssec/dnskey.rs index 52b0abf2f..c970429d3 100644 --- a/src/new_rdata/dnssec/dnskey.rs +++ b/src/new_rdata/dnssec/dnskey.rs @@ -1,8 +1,11 @@ -use core::fmt; +use core::{cmp::Ordering, fmt}; use domain_macros::*; -use crate::new_base::wire::U16; +use crate::new_base::{ + wire::{AsBytes, U16}, + CanonicalRecordData, +}; use super::SecAlg; @@ -27,6 +30,14 @@ pub struct DNSKey { pub key: [u8], } +//--- Canonical operations + +impl CanonicalRecordData for DNSKey { + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.as_bytes().cmp(other.as_bytes()) + } +} + //----------- DNSKeyFlags ---------------------------------------------------- /// Flags describing a [`DNSKey`]. diff --git a/src/new_rdata/dnssec/ds.rs b/src/new_rdata/dnssec/ds.rs index f4c265133..aae784864 100644 --- a/src/new_rdata/dnssec/ds.rs +++ b/src/new_rdata/dnssec/ds.rs @@ -1,8 +1,11 @@ -use core::fmt; +use core::{cmp::Ordering, fmt}; use domain_macros::*; -use crate::new_base::wire::U16; +use crate::new_base::{ + wire::{AsBytes, U16}, + CanonicalRecordData, +}; use super::SecAlg; @@ -27,6 +30,14 @@ pub struct Ds { pub digest: [u8], } +//--- Canonical operations + +impl CanonicalRecordData for Ds { + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.as_bytes().cmp(other.as_bytes()) + } +} + //----------- DigestType ----------------------------------------------------- /// A cryptographic digest algorithm. diff --git a/src/new_rdata/dnssec/nsec.rs b/src/new_rdata/dnssec/nsec.rs index 24b8a45e7..9df89d276 100644 --- a/src/new_rdata/dnssec/nsec.rs +++ b/src/new_rdata/dnssec/nsec.rs @@ -1,11 +1,11 @@ -use core::{fmt, mem}; +use core::{cmp::Ordering, fmt, mem}; use domain_macros::*; use crate::new_base::{ - name::Name, - wire::{ParseBytesByRef, ParseError}, - RType, + name::{CanonicalName, Name}, + wire::{AsBytes, ParseBytesByRef, ParseError}, + CanonicalRecordData, RType, }; //----------- NSec ----------------------------------------------------------- @@ -35,6 +35,16 @@ impl NSec<'_> { } } +//--- Canonical operations + +impl CanonicalRecordData for NSec<'_> { + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.next + .cmp_composed(other.next) + .then_with(|| self.types.as_bytes().cmp(other.types.as_bytes())) + } +} + //----------- TypeBitmaps ---------------------------------------------------- /// A bitmap of DNS record types. diff --git a/src/new_rdata/dnssec/nsec3.rs b/src/new_rdata/dnssec/nsec3.rs index 56877f4ef..1e0358157 100644 --- a/src/new_rdata/dnssec/nsec3.rs +++ b/src/new_rdata/dnssec/nsec3.rs @@ -1,8 +1,11 @@ -use core::fmt; +use core::{cmp::Ordering, fmt}; use domain_macros::*; -use crate::new_base::wire::{SizePrefixed, U16}; +use crate::new_base::{ + wire::{AsBytes, SizePrefixed, U16}, + CanonicalRecordData, +}; use super::TypeBitmaps; @@ -49,6 +52,34 @@ impl NSec3<'_> { } } +//--- Canonical operations + +impl CanonicalRecordData for NSec3<'_> { + fn cmp_canonical(&self, that: &Self) -> Ordering { + let this = ( + self.algorithm, + self.flags.as_bytes(), + self.iterations, + self.salt.len(), + self.salt, + self.next.len(), + self.next, + self.types.as_bytes(), + ); + let that = ( + that.algorithm, + that.flags.as_bytes(), + that.iterations, + that.salt.len(), + that.salt, + that.next.len(), + that.next, + that.types.as_bytes(), + ); + this.cmp(&that) + } +} + //----------- NSec3Param ----------------------------------------------------- /// Parameters for computing [`NSec3`] records. @@ -77,6 +108,12 @@ pub struct NSec3Param { pub salt: SizePrefixed, } +impl CanonicalRecordData for NSec3Param { + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.as_bytes().cmp(other.as_bytes()) + } +} + //----------- NSec3HashAlg --------------------------------------------------- /// The hash algorithm used with [`NSec3`] records. diff --git a/src/new_rdata/dnssec/rrsig.rs b/src/new_rdata/dnssec/rrsig.rs index a717caa9e..e63dfcc74 100644 --- a/src/new_rdata/dnssec/rrsig.rs +++ b/src/new_rdata/dnssec/rrsig.rs @@ -1,6 +1,12 @@ +use core::cmp::Ordering; + use domain_macros::*; -use crate::new_base::{name::Name, wire::U16, RType, Serial, TTL}; +use crate::new_base::{ + name::{CanonicalName, Name}, + wire::{AsBytes, U16}, + CanonicalRecordData, RType, Serial, TTL, +}; use super::SecAlg; @@ -52,3 +58,32 @@ impl RRSig<'_> { } } } + +//--- Canonical operations + +impl CanonicalRecordData for RRSig<'_> { + fn cmp_canonical(&self, that: &Self) -> Ordering { + let this_initial = ( + self.rtype, + self.algorithm, + self.labels, + self.ttl, + self.expiration.as_bytes(), + self.inception.as_bytes(), + self.keytag, + ); + let that_initial = ( + that.rtype, + that.algorithm, + that.labels, + that.ttl, + that.expiration.as_bytes(), + that.inception.as_bytes(), + that.keytag, + ); + this_initial + .cmp(&that_initial) + .then_with(|| self.signer.cmp_lowercase_composed(that.signer)) + .then_with(|| self.signature.cmp(that.signature)) + } +} diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 9389d6984..c9545d569 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -2,12 +2,15 @@ //! //! See [RFC 6891](https://datatracker.ietf.org/doc/html/rfc6891). -use core::{fmt, iter::FusedIterator}; +use core::{cmp::Ordering, fmt, iter::FusedIterator}; use domain_macros::*; use crate::{ - new_base::wire::{ParseError, SplitBytes}, + new_base::{ + wire::{ParseError, SplitBytes}, + CanonicalRecordData, + }, new_edns::EdnsOption, }; @@ -40,6 +43,14 @@ impl Opt { } } +//--- Canonical operations + +impl CanonicalRecordData for Opt { + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.contents.cmp(&other.contents) + } +} + //--- Formatting impl fmt::Debug for Opt { diff --git a/src/new_rdata/ipv6.rs b/src/new_rdata/ipv6.rs index f91ae5e7b..feacd06d4 100644 --- a/src/new_rdata/ipv6.rs +++ b/src/new_rdata/ipv6.rs @@ -2,6 +2,7 @@ //! //! See [RFC 3596](https://datatracker.ietf.org/doc/html/rfc3596). +use core::cmp::Ordering; #[cfg(feature = "std")] use core::{fmt, str::FromStr}; @@ -13,6 +14,7 @@ use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, wire::AsBytes, + CanonicalRecordData, }; //----------- Aaaa ----------------------------------------------------------- @@ -58,6 +60,14 @@ impl From for Ipv6Addr { } } +//--- Canonical operations + +impl CanonicalRecordData for Aaaa { + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.octets.cmp(&other.octets) + } +} + //--- Parsing from a string #[cfg(feature = "std")] diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 82c408259..c76c93c83 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -1,15 +1,18 @@ //! Record data types. +use core::cmp::Ordering; + use domain_macros::*; use crate::new_base::{ build::{self, BuildIntoMessage, BuildResult}, + name::CanonicalName, parse::{ParseMessageBytes, SplitMessageBytes}, wire::{ AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError, }, - ParseRecordData, RType, + CanonicalRecordData, ParseRecordData, RType, }; //----------- Concrete record data types ------------------------------------- @@ -210,6 +213,66 @@ impl<'a, N> RecordData<'a, N> { } } +//--- Canonical operations + +impl CanonicalRecordData for RecordData<'_, N> { + fn build_canonical_bytes<'b>( + &self, + bytes: &'b mut [u8], + ) -> Result<&'b mut [u8], TruncationError> { + match self { + Self::A(r) => r.build_canonical_bytes(bytes), + Self::Ns(r) => r.build_canonical_bytes(bytes), + Self::CName(r) => r.build_canonical_bytes(bytes), + Self::Soa(r) => r.build_canonical_bytes(bytes), + Self::Wks(r) => r.build_canonical_bytes(bytes), + Self::Ptr(r) => r.build_canonical_bytes(bytes), + Self::HInfo(r) => r.build_canonical_bytes(bytes), + Self::Mx(r) => r.build_canonical_bytes(bytes), + Self::Txt(r) => r.build_canonical_bytes(bytes), + Self::Aaaa(r) => r.build_canonical_bytes(bytes), + Self::Opt(r) => r.build_canonical_bytes(bytes), + Self::Ds(r) => r.build_canonical_bytes(bytes), + Self::RRSig(r) => r.build_canonical_bytes(bytes), + Self::NSec(r) => r.build_canonical_bytes(bytes), + Self::DNSKey(r) => r.build_canonical_bytes(bytes), + Self::NSec3(r) => r.build_canonical_bytes(bytes), + Self::NSec3Param(r) => r.build_canonical_bytes(bytes), + Self::Unknown(_, rd) => rd.build_canonical_bytes(bytes), + } + } + + fn cmp_canonical(&self, other: &Self) -> Ordering { + self.rtype() + .cmp(&other.rtype()) + .then_with(|| match (self, other) { + (Self::A(l), Self::A(r)) => l.cmp_canonical(r), + (Self::Ns(l), Self::Ns(r)) => l.cmp_canonical(r), + (Self::CName(l), Self::CName(r)) => l.cmp_canonical(r), + (Self::Soa(l), Self::Soa(r)) => l.cmp_canonical(r), + (Self::Wks(l), Self::Wks(r)) => l.cmp_canonical(r), + (Self::Ptr(l), Self::Ptr(r)) => l.cmp_canonical(r), + (Self::HInfo(l), Self::HInfo(r)) => l.cmp_canonical(r), + (Self::Mx(l), Self::Mx(r)) => l.cmp_canonical(r), + (Self::Txt(l), Self::Txt(r)) => l.cmp_canonical(r), + (Self::Aaaa(l), Self::Aaaa(r)) => l.cmp_canonical(r), + (Self::Opt(l), Self::Opt(r)) => l.cmp_canonical(r), + (Self::Ds(l), Self::Ds(r)) => l.cmp_canonical(r), + (Self::RRSig(l), Self::RRSig(r)) => l.cmp_canonical(r), + (Self::NSec(l), Self::NSec(r)) => l.cmp_canonical(r), + (Self::DNSKey(l), Self::DNSKey(r)) => l.cmp_canonical(r), + (Self::NSec3(l), Self::NSec3(r)) => l.cmp_canonical(r), + (Self::NSec3Param(l), Self::NSec3Param(r)) => { + l.cmp_canonical(r) + } + (Self::Unknown(_, l), Self::Unknown(_, r)) => { + l.cmp_canonical(r) + } + _ => unreachable!(), + }) + } +} + //--- Parsing record data impl<'a, N> ParseRecordData<'a> for RecordData<'a, N> @@ -389,3 +452,13 @@ impl UnknownRecordData { unsafe { Self::parse_bytes_by_mut(bytes).unwrap_unchecked() } } } + +//--- Canonical operations + +impl CanonicalRecordData for UnknownRecordData { + fn cmp_canonical(&self, other: &Self) -> Ordering { + // Since this is not a well-known record data type, embedded domain + // names do not need to be lowercased. + self.octets.cmp(&other.octets) + } +} From 2f629939cff1398cd2c7e2a2b2a1ace7258df675 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 20 Mar 2025 14:42:24 +0100 Subject: [PATCH 159/167] [new_base/record] Fix broken doc link --- src/new_base/record.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/new_base/record.rs b/src/new_base/record.rs index e435e9dd1..4309d7701 100644 --- a/src/new_base/record.rs +++ b/src/new_base/record.rs @@ -496,6 +496,8 @@ pub trait CanonicalRecordData: BuildBytes { /// /// This is equivalent to serializing both record data instances using /// [`build_canonical_bytes()`] and comparing the resulting byte strings. + /// + /// [`build_canonical_bytes()`]: Self::build_canonical_bytes() fn cmp_canonical(&self, other: &Self) -> Ordering; } From 11560a3c3ffec52fc75a97f3fdccfe122387c55d Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 20 Mar 2025 14:42:34 +0100 Subject: [PATCH 160/167] [new_base] Add module 'compat' --- src/new_base/mod.rs | 167 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index b3903d90c..493c9f82b 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -33,3 +33,170 @@ pub use serial::Serial; pub mod build; pub mod parse; pub mod wire; + +//--- Compatibility exports + +/// A compatibility module with [`domain::base`]. +/// +/// This re-exports a large part of the `new_base` API surface using the same +/// import paths as the old `base` module. It is a stopgap measure to help +/// users port existing code over to `new_base`. Every export comes with a +/// deprecation message to help users switch to the right tools. +pub mod compat { + #![allow(deprecated)] + + #[deprecated = "use 'crate::new_base::HeaderFlags' instead."] + pub use header::Flags; + + #[deprecated = "use 'crate::new_base::Header' instead."] + pub use header::HeaderSection; + + #[deprecated = "use 'crate::new_base::SectionCounts' instead."] + pub use header::HeaderCounts; + + #[deprecated = "use 'crate::new_base::RType' instead."] + pub use iana::rtype::Rtype; + + #[deprecated = "use 'crate::new_base::Message' instead."] + pub use message::Message; + + #[deprecated = "use 'crate::new_base::build::MessageBuilder' instead."] + pub use message_builder::MessageBuilder; + + #[deprecated = "use 'crate::new_base::name::Label' instead."] + pub use name::Label; + + #[deprecated = "use 'crate::new_base::name::Name' instead."] + pub use name::Name; + + #[deprecated = "use 'crate::new_base::Question' instead."] + pub use question::Question; + + #[deprecated = "use 'crate::new_base::ParseRecordData' instead."] + pub use rdata::ParseRecordData; + + #[deprecated = "use 'crate::new_rdata::UnknownRecordData' instead."] + pub use rdata::UnknownRecordData; + + #[deprecated = "use 'crate::new_base::Record' instead."] + pub use record::Record; + + #[deprecated = "use 'crate::new_base::TTL' instead."] + pub use record::Ttl; + + #[deprecated = "use 'crate::new_base::Serial' instead."] + pub use serial::Serial; + + pub mod header { + #[deprecated = "use 'crate::new_base::HeaderFlags' instead."] + pub use crate::new_base::HeaderFlags as Flags; + + #[deprecated = "use 'crate::new_base::Header' instead."] + pub use crate::new_base::Header as HeaderSection; + + #[deprecated = "use 'crate::new_base::SectionCounts' instead."] + pub use crate::new_base::SectionCounts as HeaderCounts; + } + + pub mod iana { + #[deprecated = "use 'crate::new_base::RClass' instead."] + pub use class::Class; + + #[deprecated = "use 'crate::new_rdata::DigestType' instead."] + pub use digestalg::DigestAlg; + + #[deprecated = "use 'crate::new_rdata::NSec3HashAlg' instead."] + pub use nsec3::Nsec3HashAlg; + + #[deprecated = "use 'crate::new_edns::OptionCode' instead."] + pub use opt::OptionCode; + + #[deprecated = "for now, just use 'u8', but a better API is coming."] + pub use rcode::Rcode; + + #[deprecated = "use 'crate::new_base::RType' instead."] + pub use rtype::Rtype; + + #[deprecated = "use 'crate::new_rdata::SecAlg' instead."] + pub use secalg::SecAlg; + + pub mod class { + #[deprecated = "use 'crate::new_base::RClass' instead."] + pub use crate::new_base::RClass as Class; + } + + pub mod digestalg { + #[deprecated = "use 'crate::new_rdata::DigestType' instead."] + pub use crate::new_rdata::DigestType as DigestAlg; + } + + pub mod nsec3 { + #[deprecated = "use 'crate::new_rdata::NSec3HashAlg' instead."] + pub use crate::new_rdata::NSec3HashAlg as Nsec3HashAlg; + } + + pub mod opt { + #[deprecated = "use 'crate::new_edns::OptionCode' instead."] + pub use crate::new_edns::OptionCode; + } + + pub mod rcode { + #[deprecated = "for now, just use 'u8', but a better API is coming."] + pub use u8 as Rcode; + } + + pub mod rtype { + #[deprecated = "use 'crate::new_base::RType' instead."] + pub use crate::new_base::RType as Rtype; + } + + pub mod secalg { + #[deprecated = "use 'crate::new_rdata::SecAlg' instead."] + pub use crate::new_rdata::SecAlg; + } + } + + pub mod message { + #[deprecated = "use 'crate::new_base::Message' instead."] + pub use crate::new_base::Message; + } + + pub mod message_builder { + #[deprecated = "use 'crate::new_base::build::MessageBuilder' instead."] + pub use crate::new_base::build::MessageBuilder; + } + + pub mod name { + #[deprecated = "use 'crate::new_base::name::Label' instead."] + pub use crate::new_base::name::Label; + + #[deprecated = "use 'crate::new_base::name::Name' instead."] + pub use crate::new_base::name::Name; + } + + pub mod question { + #[deprecated = "use 'crate::new_base::Question' instead."] + pub use crate::new_base::Question; + } + + pub mod rdata { + #[deprecated = "use 'crate::new_base::ParseRecordData' instead."] + pub use crate::new_base::ParseRecordData; + + #[deprecated = "use 'crate::new_rdata::UnknownRecordData' instead."] + pub use crate::new_rdata::UnknownRecordData; + } + + pub mod record { + #[deprecated = "use 'crate::new_base::Record' instead."] + pub use crate::new_base::Record; + + #[deprecated = "use 'crate::new_base::TTL' instead."] + pub use crate::new_base::TTL as Ttl; + } + + pub mod serial { + #[deprecated = "use 'crate::new_base::Serial' instead."] + pub use crate::new_base::Serial; + } +} From 9c67fc6a69c377b63cc0f70e28cecad28a4bfc08 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 20 Mar 2025 14:43:49 +0100 Subject: [PATCH 161/167] [new_rdata/edns] Add 'Opt::EMPTY' --- src/new_rdata/edns.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index c9545d569..a2a4c881d 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -34,6 +34,14 @@ pub struct Opt { contents: [u8], } +//--- Associated Constants + +impl Opt { + /// Empty OPT record data. + pub const EMPTY: &'static Self = + unsafe { core::mem::transmute(&[] as &[u8]) }; +} + //--- Inspection impl Opt { From e81f1bce6decc10668df6be4347b1c6ace7b6616 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 20 Mar 2025 14:58:59 +0100 Subject: [PATCH 162/167] [new_base/build] Make 'MessageBuilder::finish()' return a mutable ref --- src/new_base/build/message.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 87767528f..5ef3eebd0 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -88,8 +88,11 @@ impl MessageBuilder<'_, '_> { impl<'b> MessageBuilder<'b, '_> { /// End the builder, returning the built message. - pub fn finish(self) -> &'b Message { - self.message.slice_to(self.context.size) + /// + /// The returned message is valid, but it can be modified by the caller + /// arbitrarily; avoid modifying the message beyond the header. + pub fn finish(self) -> &'b mut Message { + self.message.slice_to_mut(self.context.size) } /// Reborrow the builder with a shorter lifetime. From 7090d7b2c0b0db455d1de0f462770268d1fbb0ab Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 20 Mar 2025 15:21:52 +0100 Subject: [PATCH 163/167] [new_base/build] Add 'must_use' in a few vital places This helps prevent users from accidentally forgetting to commit the components of the message they're building. --- src/new_base/build/message.rs | 9 +++++++++ src/new_base/build/question.rs | 1 + src/new_base/build/record.rs | 1 + 3 files changed, 11 insertions(+) diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 5ef3eebd0..83cba7f1a 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -39,6 +39,7 @@ impl<'b, 'c> MessageBuilder<'b, 'c> { /// /// Panics if the buffer is less than 12 bytes long (which is the minimum /// possible size for a DNS message). + #[must_use] pub fn new( buffer: &'b mut [u8], context: &'c mut BuilderContext, @@ -54,16 +55,19 @@ impl<'b, 'c> MessageBuilder<'b, 'c> { impl MessageBuilder<'_, '_> { /// The message header. + #[must_use] pub fn header(&self) -> &Header { &self.message.header } /// The message header, mutably. + #[must_use] pub fn header_mut(&mut self) -> &mut Header { &mut self.message.header } /// The message built thus far. + #[must_use] pub fn message(&self) -> &Message { self.message.slice_to(self.context.size) } @@ -74,11 +78,13 @@ impl MessageBuilder<'_, '_> { /// /// The caller must not modify any compressed names among these bytes. /// This can invalidate name compression state. + #[must_use] pub unsafe fn message_mut(&mut self) -> &mut Message { self.message.slice_to_mut(self.context.size) } /// The builder context. + #[must_use] pub fn context(&self) -> &BuilderContext { self.context } @@ -91,11 +97,13 @@ impl<'b> MessageBuilder<'b, '_> { /// /// The returned message is valid, but it can be modified by the caller /// arbitrarily; avoid modifying the message beyond the header. + #[must_use] pub fn finish(self) -> &'b mut Message { self.message.slice_to_mut(self.context.size) } /// Reborrow the builder with a shorter lifetime. + #[must_use] pub fn reborrow(&mut self) -> MessageBuilder<'_, '_> { MessageBuilder { message: self.message, @@ -138,6 +146,7 @@ impl<'b> MessageBuilder<'b, '_> { } /// Obtain a [`Builder`]. + #[must_use] pub(super) fn builder(&mut self, start: usize) -> Builder<'_> { debug_assert!(start <= self.context.size); unsafe { diff --git a/src/new_base/build/question.rs b/src/new_base/build/question.rs index addd72d91..8f6f04c3c 100644 --- a/src/new_base/build/question.rs +++ b/src/new_base/build/question.rs @@ -17,6 +17,7 @@ use super::{BuildCommitted, BuildIntoMessage, MessageBuilder, MessageState}; /// appended to a message (using a [`MessageBuilder`]). It can be used to /// inspect the question's fields, to replace it with a new question, and to /// commit (finish building) or cancel (remove) the question. +#[must_use = "A 'QuestionBuilder' must be explicitly committed, else all added content will be lost"] pub struct QuestionBuilder<'b> { /// The underlying message builder. builder: MessageBuilder<'b, 'b>, diff --git a/src/new_base/build/record.rs b/src/new_base/build/record.rs index e43495cac..327f289d8 100644 --- a/src/new_base/build/record.rs +++ b/src/new_base/build/record.rs @@ -21,6 +21,7 @@ use super::{ /// a DNS message (using a [`MessageBuilder`]). It can be used to inspect the /// record, to (re)write the record data, and to commit (finish building) or /// cancel (remove) the record. +#[must_use = "A 'RecordBuilder' must be explicitly committed, else all added content will be lost"] pub struct RecordBuilder<'b> { /// The underlying message builder. builder: MessageBuilder<'b, 'b>, From 4a9cf40488cf551114e99c93cd2fc91547bc2825 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 20 Mar 2025 15:22:26 +0100 Subject: [PATCH 164/167] [new_base/compat] Remove hard-to-port re-exports --- src/new_base/mod.rs | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/new_base/mod.rs b/src/new_base/mod.rs index 493c9f82b..3e22e27fc 100644 --- a/src/new_base/mod.rs +++ b/src/new_base/mod.rs @@ -57,12 +57,6 @@ pub mod compat { #[deprecated = "use 'crate::new_base::RType' instead."] pub use iana::rtype::Rtype; - #[deprecated = "use 'crate::new_base::Message' instead."] - pub use message::Message; - - #[deprecated = "use 'crate::new_base::build::MessageBuilder' instead."] - pub use message_builder::MessageBuilder; - #[deprecated = "use 'crate::new_base::name::Label' instead."] pub use name::Label; @@ -156,16 +150,6 @@ pub mod compat { } } - pub mod message { - #[deprecated = "use 'crate::new_base::Message' instead."] - pub use crate::new_base::Message; - } - - pub mod message_builder { - #[deprecated = "use 'crate::new_base::build::MessageBuilder' instead."] - pub use crate::new_base::build::MessageBuilder; - } - pub mod name { #[deprecated = "use 'crate::new_base::name::Label' instead."] pub use crate::new_base::name::Label; From a5a692f6af6221fd2896d2c7130dc3ec1d7b4722 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 20 Mar 2025 15:32:36 +0100 Subject: [PATCH 165/167] [new_edns] Add 'EdnsRecord::clone_to_bump()' --- src/new_edns/cookie.rs | 9 ++++++++- src/new_edns/ext_err.rs | 2 +- src/new_edns/mod.rs | 29 ++++++++++++++++++++++++++++- 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs index 77c810be5..32a770a3a 100644 --- a/src/new_edns/cookie.rs +++ b/src/new_edns/cookie.rs @@ -131,7 +131,14 @@ impl fmt::Display for ClientCookie { /// A DNS cookie. #[derive( - Debug, PartialEq, Eq, Hash, AsBytes, BuildBytes, ParseBytesByRef, + Debug, + PartialEq, + Eq, + Hash, + AsBytes, + BuildBytes, + ParseBytesByRef, + UnsizedClone, )] #[repr(C)] pub struct Cookie { diff --git a/src/new_edns/ext_err.rs b/src/new_edns/ext_err.rs index 7613afd8e..510ee4963 100644 --- a/src/new_edns/ext_err.rs +++ b/src/new_edns/ext_err.rs @@ -11,7 +11,7 @@ use crate::new_base::wire::U16; //----------- ExtError ------------------------------------------------------- /// An extended DNS error. -#[derive(AsBytes, BuildBytes, ParseBytesByRef)] +#[derive(AsBytes, BuildBytes, ParseBytesByRef, UnsizedClone)] #[repr(C)] pub struct ExtError { /// The error code. diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index fa16a241d..67edb2a80 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -219,6 +219,33 @@ impl EdnsOption<'_> { Self::Unknown(code, _) => *code, } } + + /// Copy referenced data into the given [`Bump`](bumpalo::Bump) allocator. + #[cfg(feature = "bumpalo")] + pub fn clone_to_bump<'r>( + &self, + bump: &'r bumpalo::Bump, + ) -> EdnsOption<'r> { + use crate::utils::clone_to_bump; + + match *self { + EdnsOption::ClientCookie(&client_cookie) => { + EdnsOption::ClientCookie(bump.alloc(client_cookie)) + } + EdnsOption::Cookie(cookie) => { + EdnsOption::Cookie(clone_to_bump(cookie, bump)) + } + EdnsOption::ExtError(ext_error) => { + EdnsOption::ExtError(clone_to_bump(ext_error, bump)) + } + EdnsOption::Unknown(option_code, unknown_option) => { + EdnsOption::Unknown( + option_code, + clone_to_bump(unknown_option, bump), + ) + } + } + } } //--- Parsing from bytes @@ -338,7 +365,7 @@ impl fmt::Debug for OptionCode { //----------- UnknownOption -------------------------------------------------- /// Data for an unknown Extended DNS option. -#[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef)] +#[derive(Debug, AsBytes, BuildBytes, ParseBytesByRef, UnsizedClone)] #[repr(transparent)] pub struct UnknownOption { /// The unparsed option data. From 285019da1673ebfeb3d4e91b8a138ff376516e99 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Thu, 20 Mar 2025 15:48:29 +0100 Subject: [PATCH 166/167] [new_base/message] Make 'HeaderFlags' setters modify in place --- src/new_base/build/message.rs | 3 +-- src/new_base/message.rs | 14 +++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/new_base/build/message.rs b/src/new_base/build/message.rs index 83cba7f1a..de38af4fa 100644 --- a/src/new_base/build/message.rs +++ b/src/new_base/build/message.rs @@ -140,8 +140,7 @@ impl<'b> MessageBuilder<'b, '_> { /// /// This will remove all message contents and mark it as truncated. pub fn truncate(&mut self) { - self.message.header.flags = - self.message.header.flags.set_truncated(true); + self.message.header.flags.set_truncated(true); *self.context = BuilderContext::default(); } diff --git a/src/new_base/message.rs b/src/new_base/message.rs index 7df0c47e1..358a75e2a 100644 --- a/src/new_base/message.rs +++ b/src/new_base/message.rs @@ -150,7 +150,7 @@ impl HeaderFlags { } /// Set the specified flag bit. - fn set_flag(mut self, pos: u32, value: bool) -> Self { + fn set_flag(&mut self, pos: u32, value: bool) -> &mut Self { self.inner &= !(1 << pos); self.inner |= (value as u16) << pos; self @@ -182,7 +182,7 @@ impl HeaderFlags { } /// Construct a query. - pub fn query(mut self, opcode: u8) -> Self { + pub fn query(&mut self, opcode: u8) -> &mut Self { assert!(opcode < 16); self.inner &= !(0xF << 11); self.inner |= (opcode as u16) << 11; @@ -190,7 +190,7 @@ impl HeaderFlags { } /// Construct a response. - pub fn respond(mut self, rcode: u8) -> Self { + pub fn respond(&mut self, rcode: u8) -> &mut Self { assert!(rcode < 16); self.inner &= !0xF; self.inner |= rcode as u16; @@ -203,7 +203,7 @@ impl HeaderFlags { } /// Mark this as an authoritative answer. - pub fn set_authoritative(self, value: bool) -> Self { + pub fn set_authoritative(&mut self, value: bool) -> &mut Self { self.set_flag(10, value) } @@ -213,7 +213,7 @@ impl HeaderFlags { } /// Mark this message as truncated. - pub fn set_truncated(self, value: bool) -> Self { + pub fn set_truncated(&mut self, value: bool) -> &mut Self { self.set_flag(9, value) } @@ -223,7 +223,7 @@ impl HeaderFlags { } /// Direct the server to query recursively. - pub fn request_recursion(self, value: bool) -> Self { + pub fn request_recursion(&mut self, value: bool) -> &mut Self { self.set_flag(8, value) } @@ -233,7 +233,7 @@ impl HeaderFlags { } /// Indicate support for recursive queries. - pub fn support_recursion(self, value: bool) -> Self { + pub fn support_recursion(&mut self, value: bool) -> &mut Self { self.set_flag(7, value) } } From c88b36b001af163d23129681bf4ab8c65e273935 Mon Sep 17 00:00:00 2001 From: arya dradjica Date: Fri, 21 Mar 2025 09:55:33 +0100 Subject: [PATCH 167/167] [new_base/build] Fix broken doc test --- src/new_base/build/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/new_base/build/mod.rs b/src/new_base/build/mod.rs index 871d20571..a028e3b3d 100644 --- a/src/new_base/build/mod.rs +++ b/src/new_base/build/mod.rs @@ -24,7 +24,7 @@ //! // Select a randomized ID here. //! id: U16::new(1234), //! // A recursive query for authoritative data. -//! flags: HeaderFlags::default() +//! flags: *HeaderFlags::default() //! .query(0) //! .set_authoritative(true) //! .request_recursion(true),