diff --git a/src/net/server/connection.rs b/src/net/server/connection.rs index a09609a05..6e81292dc 100644 --- a/src/net/server/connection.rs +++ b/src/net/server/connection.rs @@ -1,14 +1,16 @@ //! Support for stream based connections. +use core::future::Future; use core::ops::{ControlFlow, Deref}; +use core::pin::Pin; use core::time::Duration; +use std::boxed::Box; use std::fmt::Display; use std::io; use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwap; -use futures_util::StreamExt; use log::{log_enabled, Level}; use octseq::Octets; use tokio::io::{ @@ -20,16 +22,18 @@ use tokio::time::Instant; use tokio::time::{sleep_until, timeout}; use tracing::{debug, error, trace, warn}; +use crate::base::iana::OptRcode; use crate::base::message_builder::AdditionalBuilder; use crate::base::wire::Composer; use crate::base::{Message, StreamTarget}; use crate::net::server::buf::BufSource; use crate::net::server::message::Request; use crate::net::server::metrics::ServerMetrics; -use crate::net::server::service::{Service, ServiceFeedback}; -use crate::net::server::util::to_pcap_text; +use crate::net::server::service::Service; +use crate::net::server::util::{mk_error_response, to_pcap_text}; use crate::utils::config::DefMinMax; +use super::invoker::{InvokerStatus, ServiceInvoker}; use super::message::{NonUdpTransportContext, TransportSpecificContext}; use super::stream::Config as ServerConfig; use super::ServerCommand; @@ -221,9 +225,10 @@ impl Clone for Config { /// A handler for a single stream connection between client and server. pub struct Connection where - Buf: BufSource, - Buf::Output: Send + Sync + Unpin, - Svc: Service + Clone, + Buf: BufSource + Clone + Send + Sync + 'static, + Buf::Output: Octets + Send + Sync + Unpin, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Target: Composer + Default + Send, { /// Flag used by the Drop impl to track if the metric count has to be /// decreased or not. 
@@ -266,6 +271,9 @@ where /// [`ServerMetrics`] describing the status of the server. metrics: Arc, + + /// Dispatches requests to the service and enqueues responses for sending. + request_dispatcher: ServiceResponseHandler, } /// Creation @@ -273,9 +281,12 @@ where impl Connection where Stream: AsyncRead + AsyncWrite, - Buf: BufSource, + Buf: BufSource + Clone + Send + Sync, Buf::Output: Octets + Send + Sync + Unpin, - Svc: Service + Clone, + Svc: Service + Clone + Send + Sync, + Svc::Target: Composer + Default + Send, + Svc::Stream: Send, + Svc::Future: Send, { /// Creates a new handler for an accepted stream connection. #[must_use] @@ -322,6 +333,12 @@ where // uses of self we have to do while running. let stream_rx = Some(stream_rx); + let request_dispatcher = ServiceResponseHandler::new( + config.clone(), + result_q_tx.clone(), + metrics.clone(), + ); + Self { active: false, buf, @@ -334,6 +351,7 @@ where service, idle_timer, metrics, + request_dispatcher, } } } @@ -346,8 +364,9 @@ where Buf: BufSource + Send + Sync + Clone + 'static, Buf::Output: Octets + Send + Sync + Unpin, Svc: Service + Clone + Send + Sync + 'static, - Svc::Target: Composer + Send, + Svc::Target: Composer + Default + Send, Svc::Stream: Send, + Svc::Future: Send, { /// Start reading requests and writing responses to the stream. /// @@ -363,9 +382,7 @@ where pub async fn run( mut self, command_rx: watch::Receiver>, - ) where - Svc::Future: Send, - { + ) { self.metrics.inc_num_connections(); // Flag that we have to decrease the metric count on Drop. 
@@ -383,7 +400,7 @@ where Buf: BufSource + Send + Sync + Clone + 'static, Buf::Output: Octets + Send + Sync + Unpin, Svc: Service + Clone + Send + Sync + 'static, - Svc::Target: Composer + Send, + Svc::Target: Composer + Default + Send, Svc::Future: Send, Svc::Stream: Send, { @@ -585,7 +602,7 @@ where if log_enabled!(Level::Trace) { let bytes = msg.as_dgram_slice(); let pcap_text = to_pcap_text(bytes, bytes.len()); - trace!(addr = %self.addr, pcap_text, "Sending response"); + trace!(addr = %self.addr, pcap_text, "Sending {} bytes of response to {}", bytes.len(), self.addr); } match timeout( @@ -642,6 +659,7 @@ where ) -> Result<(), ConnectionEvent> where Svc::Stream: Send, + Svc::Target: Default, { match res { Ok(buf) => { @@ -663,6 +681,9 @@ where tracing::warn!( "Failed while parsing request message: {err}" ); + // Consider the client to be a threat to us if it is + // sending garbage that we can't parse: disconnect it + // immediately. return Err(ConnectionEvent::DisconnectWithoutFlush); } @@ -673,118 +694,39 @@ where Ok(msg) if msg.header().qr() => { // TO DO: Count this event? 
trace!("Ignoring received message because it is a reply, not a query."); + let response = + mk_error_response::( + &msg, + OptRcode::FORMERR, + ); + let dispatcher = self.request_dispatcher.clone(); + tokio::spawn(async move { + dispatcher.do_enqueue_response(response).await; + }); } Ok(msg) => { let ctx = NonUdpTransportContext::new(Some( self.config.load().idle_timeout, )); - let ctx = TransportSpecificContext::NonUdp(ctx); + let request = Request::new( self.addr, received_at, msg, - ctx, + TransportSpecificContext::NonUdp(ctx), (), ); - let svc = self.service.clone(); - let result_q_tx = self.result_q_tx.clone(); - let metrics = self.metrics.clone(); - let config = self.config.clone(); - trace!( "Spawning task to handle new message with id {}", request.message().header().id() ); + + let mut dispatcher = self.request_dispatcher.clone(); + let service = self.service.clone(); tokio::spawn(async move { - let request_id = request.message().header().id(); - trace!( - "Calling service for request id {request_id}" - ); - let mut stream = svc.call(request).await; - let mut in_transaction = false; - - trace!("Awaiting service call results for request id {request_id}"); - while let Some(Ok(call_result)) = - stream.next().await - { - trace!("Processing service call result for request id {request_id}"); - let (response, feedback) = - call_result.into_inner(); - - if let Some(feedback) = feedback { - match feedback { - ServiceFeedback::Reconfigure { - idle_timeout, - } => { - if let Some(idle_timeout) = - idle_timeout - { - debug!( - "Reconfigured connection timeout to {idle_timeout:?}" - ); - let guard = config.load(); - let mut new_config = **guard; - new_config.idle_timeout = - idle_timeout; - config.store(Arc::new( - new_config, - )); - } - } - - ServiceFeedback::BeginTransaction => { - in_transaction = true; - } - - ServiceFeedback::EndTransaction => { - in_transaction = false; - } - } - } - - if let Some(mut response) = response { - loop { - match 
result_q_tx.try_send(response) { - Ok(()) => { - let pending_writes = - result_q_tx - .max_capacity() - - result_q_tx - .capacity(); - trace!("Queued message for sending: # pending writes={pending_writes}"); - metrics - .set_num_pending_writes( - pending_writes, - ); - break; - } - - Err(TrySendError::Closed(_)) => { - error!("Unable to queue message for sending: connection is shutting down."); - return; - } - - Err(TrySendError::Full( - unused_response, - )) => { - if in_transaction { - // Wait until there is space in the message queue. - tokio::task::yield_now() - .await; - response = - unused_response; - } else { - error!("Unable to queue message for sending: queue is full."); - return; - } - } - } - } - } - } - trace!("Finished processing service call results for request id {request_id}"); + dispatcher.dispatch(request, service, ()).await }); } } @@ -801,9 +743,10 @@ where impl Drop for Connection where - Buf: BufSource, - Buf::Output: Send + Sync + Unpin, - Svc: Service + Clone, + Buf: BufSource + Clone + Send + Sync, + Buf::Output: Octets + Send + Sync + Unpin, + Svc: Service + Clone + Send + Sync, + Svc::Target: Composer + Default + Send, { fn drop(&mut self) { if self.active { @@ -1065,3 +1008,147 @@ impl IdleTimer { self.reset_idle_timer() } } + +//------------ ServiceResponseHandler ----------------------------------------- + +/// Handles responses from the [`Service`] impl. +struct ServiceResponseHandler +where + RequestOctets: Octets + Send + Sync, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Target: Composer + Default + Send, +{ + /// User supplied settings that influence our behaviour. + /// + /// May updated during request and response processing based on received + /// [`ServiceFeedback`]. + config: Arc>, + + /// The writer for pushing ready responses onto the queue waiting + /// to be written back the client. + result_q_tx: mpsc::Sender>>, + + /// [`ServerMetrics`] describing the status of the server. 
+ metrics: Arc, + + /// The status of the service invoker. + status: InvokerStatus, +} + +impl ServiceResponseHandler +where + RequestOctets: Octets + Send + Sync, + Svc: Service + Clone + Send + Sync, + Svc::Target: Composer + Default + Send, +{ + /// Creates a new instance of the service response handler. + fn new( + config: Arc>, + result_q_tx: mpsc::Sender< + AdditionalBuilder>, + >, + metrics: Arc, + ) -> Self { + Self { + config, + result_q_tx, + metrics, + status: InvokerStatus::Normal, + } + } + + /// Apply changes to our configuration as requested by the [`Service`] + /// impl. + fn update_config(&self, idle_timeout: Option) { + if let Some(idle_timeout) = idle_timeout { + debug!("Reconfigured connection timeout to {idle_timeout:?}"); + let guard = self.config.load(); + let mut new_config = **guard; + new_config.idle_timeout = idle_timeout; + self.config.store(Arc::new(new_config)); + } + } + + /// Enqueue a response from the [`Service`] impl for writing back to the + /// client. + async fn do_enqueue_response( + &self, + mut response: AdditionalBuilder>, + ) { + loop { + match self.result_q_tx.try_send(response) { + Ok(()) => { + let pending_writes = self.result_q_tx.max_capacity() + - self.result_q_tx.capacity(); + trace!("Queued message for sending: # pending writes={pending_writes}"); + self.metrics.set_num_pending_writes(pending_writes); + break; + } + + Err(TrySendError::Closed(_)) => { + error!("Unable to queue message for sending: connection is shutting down."); + break; + } + + Err(TrySendError::Full(unused_response)) => { + if matches!(self.status, InvokerStatus::InTransaction) { + // Wait until there is space in the message queue. 
+ tokio::task::yield_now().await; + response = unused_response; + } else { + error!("Unable to queue message for sending: queue is full."); + break; + } + } + } + } + } +} + +//--- Clone + +impl Clone for ServiceResponseHandler +where + RequestOctets: Octets + Send + Sync, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Target: Composer + Default + Send, +{ + fn clone(&self) -> Self { + Self { + config: self.config.clone(), + result_q_tx: self.result_q_tx.clone(), + metrics: self.metrics.clone(), + status: InvokerStatus::Normal, + } + } +} + +//--- ServiceInvoker + +impl ServiceInvoker + for ServiceResponseHandler +where + RequestOctets: Octets + Send + Sync + 'static, + Svc: Service + Clone + Send + Sync, + Svc::Target: Composer + Default + Send, +{ + fn status(&self) -> InvokerStatus { + self.status + } + + fn set_status(&mut self, status: InvokerStatus) { + self.status = status; + } + + fn reconfigure(&self, idle_timeout: Option) { + self.update_config(idle_timeout); + } + + fn enqueue_response( + &self, + response: AdditionalBuilder>, + _meta: &(), + ) -> Pin + Send + '_>> { + Box::pin(async move { self.do_enqueue_response(response).await }) + } +} diff --git a/src/net/server/dgram.rs b/src/net/server/dgram.rs index 4da2e261c..489548091 100644 --- a/src/net/server/dgram.rs +++ b/src/net/server/dgram.rs @@ -11,9 +11,12 @@ //! 
[Datagram]: https://en.wikipedia.org/wiki/Datagram use core::fmt::Debug; use core::future::poll_fn; +use core::future::Future; use core::ops::Deref; +use core::pin::Pin; use core::time::Duration; +use std::boxed::Box; use std::io; use std::net::SocketAddr; use std::string::String; @@ -21,7 +24,6 @@ use std::string::ToString; use std::sync::{Arc, Mutex}; use arc_swap::ArcSwap; -use futures_util::stream::StreamExt; use log::{log_enabled, Level}; use octseq::Octets; use tokio::io::ReadBuf; @@ -33,18 +35,22 @@ use tokio::time::Instant; use tokio::time::MissedTickBehavior; use tracing::{error, trace, warn}; +use crate::base::iana::OptRcode; +use crate::base::message_builder::AdditionalBuilder; use crate::base::wire::Composer; -use crate::base::Message; +use crate::base::{Message, StreamTarget}; use crate::net::server::buf::BufSource; use crate::net::server::error::Error; use crate::net::server::message::Request; use crate::net::server::metrics::ServerMetrics; -use crate::net::server::service::{Service, ServiceFeedback}; +use crate::net::server::service::Service; use crate::net::server::sock::AsyncDgramSock; +use crate::net::server::util::mk_error_response; use crate::net::server::util::to_pcap_text; use crate::utils::config::DefMinMax; use super::buf::VecBufSource; +use super::invoker::{InvokerStatus, ServiceInvoker}; use super::message::{TransportSpecificContext, UdpTransportContext}; use super::ServerCommand; @@ -250,15 +256,11 @@ pub struct DgramServer where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync, - ::Output: Octets + Send + Sync + Unpin + 'static, - Svc: Clone - + Service<::Output, ()> - + Send - + Sync - + 'static, - ::Output, ()>>::Future: Send, - ::Output, ()>>::Stream: Send, - ::Output, ()>>::Target: Composer + Send, + Buf::Output: Octets + Send + Sync + Unpin + 'static, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Future: Send, + Svc::Stream: Send, + Svc::Target: Composer + Default + Send, { /// The configuration of 
the server. config: Arc>, @@ -286,6 +288,9 @@ where /// [`ServerMetrics`] describing the status of the server. metrics: Arc, + + /// Dispatches requests to the service and enqueues responses for sending. + request_dispatcher: ServiceResponseHandler, } /// Creation @@ -294,11 +299,11 @@ impl DgramServer where Sock: AsyncDgramSock + Send + Sync, Buf: BufSource + Send + Sync, - ::Output: Octets + Send + Sync + Unpin, - Svc: Clone + Service<::Output, ()> + Send + Sync, - ::Output, ()>>::Future: Send, - ::Output, ()>>::Stream: Send, - ::Output, ()>>::Target: Composer + Send, + Buf::Output: Octets + Send + Sync + Unpin, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Future: Send, + Svc::Stream: Send, + Svc::Target: Composer + Default + Send, { /// Constructs a new [`DgramServer`] with default configuration. /// @@ -332,15 +337,23 @@ where let command_tx = Arc::new(Mutex::new(command_tx)); let metrics = Arc::new(ServerMetrics::connection_less()); let config = Arc::new(ArcSwap::from_pointee(config)); + let sock = Arc::new(sock); + + let request_dispatcher = ServiceResponseHandler::new( + config.clone(), + sock.clone(), + metrics.clone(), + ); DgramServer { config, command_tx, command_rx, - sock: sock.into(), + sock, buf, service, metrics, + request_dispatcher, } } } @@ -351,11 +364,11 @@ impl DgramServer where Sock: AsyncDgramSock + Send + Sync, Buf: BufSource + Send + Sync, - ::Output: Octets + Send + Sync + Unpin, - Svc: Clone + Service<::Output, ()> + Send + Sync, - ::Output, ()>>::Future: Send, - ::Output, ()>>::Stream: Send, - ::Output, ()>>::Target: Composer + Send, + Buf::Output: Octets + Send + Sync + Unpin, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Future: Send, + Svc::Stream: Send, + Svc::Target: Composer + Default + Send, { /// Get a reference to the network source being used to receive messages. 
#[must_use] @@ -376,15 +389,11 @@ impl DgramServer where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync, - ::Output: Octets + Send + Sync + 'static + Unpin, - Svc: Clone - + Service<::Output, ()> - + Send - + Sync - + 'static, - ::Output, ()>>::Future: Send, - ::Output, ()>>::Stream: Send, - ::Output, ()>>::Target: Composer + Send, + Buf::Output: Octets + Send + Sync + Unpin + 'static, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Future: Send, + Svc::Stream: Send, + Svc::Target: Composer + Default + Send, { /// Start the server. /// @@ -464,11 +473,11 @@ impl DgramServer where Sock: AsyncDgramSock + Send + Sync, Buf: BufSource + Send + Sync, - ::Output: Octets + Send + Sync + Unpin, - Svc: Clone + Service<::Output, ()> + Send + Sync, - ::Output, ()>>::Future: Send, - ::Output, ()>>::Stream: Send, - ::Output, ()>>::Target: Composer + Send, + Buf::Output: Octets + Send + Sync + Unpin + 'static, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Future: Send, + Svc::Stream: Send, + Svc::Target: Composer + Default + Send, { /// Receive incoming messages until shutdown or fatal error. async fn run_until_error(&self) -> Result<(), String> { @@ -492,93 +501,7 @@ where Err(err) => return Err(format!("Error while receiving message: {err}")), }; - let received_at = Instant::now(); - self.metrics.inc_num_received_requests(); - - if log_enabled!(Level::Trace) { - let pcap_text = to_pcap_text(&buf, bytes_read); - trace!(%addr, pcap_text, "Received message"); - } - - let svc = self.service.clone(); - let cfg = self.config.clone(); - let metrics = self.metrics.clone(); - let cloned_sock = self.sock.clone(); - let write_timeout = self.config.load().write_timeout; - - tokio::spawn(async move { - match Message::from_octets(buf) { - Err(err) => { - // TO DO: Count this event? - warn!("Failed while parsing request message: {err}"); - } - - // https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1 - // 4.1.1. 
Header section format - // "QR A one bit field that specifies whether - // this message is a query (0), or a - // response (1)." - Ok(msg) if msg.header().qr() => { - // TO DO: Count this event? - trace!("Ignoring received message because it is a reply, not a query."); - } - - Ok(msg) => { - let ctx = UdpTransportContext::new(cfg.load().max_response_size); - let ctx = TransportSpecificContext::Udp(ctx); - let request = Request::new(addr, received_at, msg, ctx, ()); - let mut stream = svc.call(request).await; - while let Some(Ok(call_result)) = stream.next().await { - let (response, feedback) = call_result.into_inner(); - - if let Some(feedback) = feedback { - match feedback { - ServiceFeedback::Reconfigure { - idle_timeout: _, // N/A - only applies to connection-oriented transports - } => { - // Nothing to do. - } - - ServiceFeedback::BeginTransaction|ServiceFeedback::EndTransaction => { - // Nothing to do. - } - } - } - - // Process the DNS response message, if any. - if let Some(response) = response { - // Convert the DNS response message into bytes. - let target = response.finish(); - let bytes = target.as_dgram_slice(); - - // Logging - if log_enabled!(Level::Trace) { - let pcap_text = to_pcap_text(bytes, bytes.len()); - trace!(%addr, pcap_text, "Sending response"); - } - - metrics.inc_num_pending_writes(); - - // Actually write the DNS response message bytes to the UDP - // socket. 
- if let Err(err) = Self::send_to( - &cloned_sock, - bytes, - &addr, - write_timeout, - ) - .await - { - warn!(%addr, "Failed to send response: {err}"); - } - - metrics.dec_num_pending_writes(); - metrics.inc_num_sent_responses(); - } - } - } - } - }); + self.process_received_message(buf, addr, bytes_read); } } } @@ -635,6 +558,66 @@ where Ok(()) } + fn process_received_message( + &self, + buf: ::Output, + addr: SocketAddr, + bytes_read: usize, + ) { + let received_at = Instant::now(); + self.metrics.inc_num_received_requests(); + + if log_enabled!(Level::Trace) { + let pcap_text = to_pcap_text(&buf, bytes_read); + trace!(%addr, pcap_text, "Received message"); + } + + match Message::from_octets(buf) { + Err(err) => { + // TO DO: Count this event? + warn!("Failed while parsing request message: {err}"); + // We can't send a response as we don't have a query ID to + // copy to the response. + } + + // https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1 + // 4.1.1. Header section format + // "QR A one bit field that specifies whether this message is + // a query (0), or a response (1)." + Ok(msg) if msg.header().qr() => { + // TO DO: Count this event? 
+ trace!("Ignoring received message because it is a reply, not a query."); + let response = mk_error_response::( + &msg, + OptRcode::FORMERR, + ); + let dispatcher = self.request_dispatcher.clone(); + tokio::spawn(async move { + dispatcher.send_response(addr, response).await; + }); + } + + Ok(msg) => { + let ctx = UdpTransportContext::new( + self.config.load().max_response_size, + ); + let ctx = TransportSpecificContext::Udp(ctx); + let request = Request::new(addr, received_at, msg, ctx, ()); + + trace!( + "Spawning task to handle new message with id {}", + request.message().header().id() + ); + + let mut dispatcher = self.request_dispatcher.clone(); + let service = self.service.clone(); + tokio::spawn(async move { + dispatcher.dispatch(request, service, addr).await + }); + } + } + } + /// Receive a single datagram using the user supplied network socket. fn recv_from( &self, @@ -645,9 +628,101 @@ where .try_recv_buf_from(&mut buf) .map(|(bytes_read, addr)| (msg, addr, bytes_read)) } +} + +//--- Drop + +impl Drop for DgramServer +where + Sock: AsyncDgramSock + Send + Sync + 'static, + Buf: BufSource + Send + Sync, + Buf::Output: Octets + Send + Sync + Unpin + 'static, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Future: Send, + Svc::Stream: Send, + Svc::Target: Composer + Default + Send, +{ + fn drop(&mut self) { + // Shutdown the DgramServer. Don't handle the failure case here as + // I'm not sure if it's safe to log or write to stderr from a Drop + // impl. + let _ = self.shutdown(); + } +} + +//------------ ServiceResponseHandler ----------------------------------------- + +/// Handles responses from the [`Service`] impl. +struct ServiceResponseHandler { + /// User supplied settings that influence our behaviour. + /// + /// May updated during request and response processing based on received + /// [`ServiceFeedback`]. + config: Arc>, + + /// The network socket to which responses will be sent. 
+ sock: Arc, + + /// [`ServerMetrics`] describing the status of the server. + metrics: Arc, + + /// The status of the service invoker. + status: InvokerStatus, +} + +impl ServiceResponseHandler +where + Sock: AsyncDgramSock + Send + Sync + 'static, +{ + /// Creates a new instance of the service response handler. + fn new( + config: Arc>, + sock: Arc, + metrics: Arc, + ) -> Self { + Self { + config, + sock, + metrics, + status: InvokerStatus::Normal, + } + } + + /// Send a response from the [`Service`] impl to the client. + async fn send_response( + &self, + addr: SocketAddr, + response: AdditionalBuilder>, + ) { + // Convert the DNS response message into bytes. + let target = response.finish(); + let bytes = target.as_dgram_slice(); + + // Logging + if log_enabled!(Level::Trace) { + let pcap_text = to_pcap_text(bytes, bytes.len()); + trace!(%addr, pcap_text, "Sending {} bytes of response to {addr}", bytes.len()); + } + + self.metrics.inc_num_pending_writes(); + + let write_timeout = self.config.load().write_timeout; + + // Actually write the DNS response message bytes to the UDP + // socket. + if let Err(err) = + Self::write_to_network(&self.sock, bytes, &addr, write_timeout) + .await + { + warn!(%addr, "Failed to send response: {err}"); + } + + self.metrics.dec_num_pending_writes(); + self.metrics.inc_num_sent_responses(); + } /// Send a single datagram using the user supplied network socket. 
- async fn send_to( + async fn write_to_network( sock: &Sock, data: &[u8], dest: &SocketAddr, @@ -671,26 +746,46 @@ where } } -//--- Drop +//--- Clone -impl Drop for DgramServer +impl Clone for ServiceResponseHandler { + fn clone(&self) -> Self { + Self { + config: self.config.clone(), + sock: self.sock.clone(), + metrics: self.metrics.clone(), + status: InvokerStatus::Normal, + } + } +} + +//--- ServiceInvoker + +impl ServiceInvoker + for ServiceResponseHandler where Sock: AsyncDgramSock + Send + Sync + 'static, - Buf: BufSource + Send + Sync, - ::Output: Octets + Send + Sync + Unpin + 'static, - Svc: Clone - + Service<::Output, ()> - + Send - + Sync - + 'static, - ::Output, ()>>::Future: Send, - ::Output, ()>>::Stream: Send, - ::Output, ()>>::Target: Composer + Send, + RequestOctets: Octets + Send + Sync + 'static, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Target: Composer + Default + Send, { - fn drop(&mut self) { - // Shutdown the DgramServer. Don't handle the failure case here as - // I'm not sure if it's safe to log or write to stderr from a Drop - // impl. - let _ = self.shutdown(); + fn status(&self) -> InvokerStatus { + self.status + } + + fn set_status(&mut self, status: InvokerStatus) { + self.status = status; + } + + fn reconfigure(&self, _idle_timeout: Option) { + // N/A + } + + fn enqueue_response<'a>( + &'a self, + response: AdditionalBuilder>, + addr: &'a SocketAddr, + ) -> Pin + Send + 'a>> { + Box::pin(async move { self.send_response(*addr, response).await }) } } diff --git a/src/net/server/invoker.rs b/src/net/server/invoker.rs new file mode 100644 index 000000000..db6e706ba --- /dev/null +++ b/src/net/server/invoker.rs @@ -0,0 +1,188 @@ +//! Common service invoking logic for network servers. +//! +//! Used by [`stream::Connection`][net::server::stream::Connection] and +//! [`dgram::Dgram`][net::server::dgram::Dgram]. 
+use core::clone::Clone; +use core::default::Default; +use core::future::Future; +use core::pin::Pin; +use core::time::Duration; +use std::boxed::Box; + +use futures_util::StreamExt; +use octseq::Octets; +use tracing::trace; + +use crate::base::message_builder::AdditionalBuilder; +use crate::base::wire::Composer; +use crate::base::{Message, StreamTarget}; + +use super::message::Request; +use super::service::{Service, ServiceFeedback, ServiceResult}; +use super::util::mk_error_response; + +//------------ InvokerStatus -------------------------------------------------- + +/// The current status of the service invoker. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum InvokerStatus { + /// Processing independent responses. + Normal, + + /// Processing related responses. + InTransaction, + + /// No more responses to the current request will be processed. + Aborting, +} + +//------------ ServiceInvoker ------------------------------------------------- + +/// Dispatch requests to a [`Service`] and do common response processing. +/// +/// Response streams will be split into individual responses and passed to the +/// trait implementer for writing back to the network. +/// +/// If the [`Service`] impl returns a [`ServiceError`] a corresponding DNS +/// error response will be created and no further responses from the service +/// for the current request will be processed and the service response stream +/// will be dropped. +/// +/// Also handles [`ServiceFeedback`] by invoking fn impls on the trait +/// implementing type. +pub trait ServiceInvoker +where + Svc: Service + Send + Sync + 'static, + Svc::Target: Composer + Default, + RequestOctets: Octets + Send + Sync + 'static, + EnqueueMeta: Send + Sync + 'static, +{ + /// Dispatch a request and process the responses. 
+ /// + /// Dispatches the given request to the given [`Service`] impl and + /// processes the stream of resulting responses, passing them to the trait + /// impl'd [`enqueue_response`] function with the provided metadata for + /// writing back to the network, until no more responses exist or the + /// trait impl'd [`status`] function reports that the state is + /// [`InvokerStatus::Aborting`]. + /// + /// On [`ServiceFeedback::Reconfigure`] passes the new configuration data + /// to the trait impl'd [`reconfigure`] function. + fn dispatch( + &mut self, + request: Request, + svc: Svc, + enqueue_meta: EnqueueMeta, + ) -> Pin + Send + '_>> + where + Self: Send + Sync, + Svc::Target: Send, + Svc::Stream: Send, + Svc::Future: Send, + { + Box::pin(async move { + let req_msg = request.message().clone(); + let request_id = request.message().header().id(); + + // Dispatch the request to the service for processing. + trace!("Calling service for request id {request_id}"); + let mut stream = svc.call(request).await; + + // Handle the resulting stream of responses, most likely just one as + // only XFR requests potentially result in multiple responses. + trace!( + "Awaiting service call results for request id {request_id}" + ); + while let Some(item) = stream.next().await { + trace!( + "Processing service call result for request id {request_id}" + ); + + let response = + self.process_response_stream_item(item, &req_msg); + + if let Some(response) = response { + self.enqueue_response(response, &enqueue_meta).await; + } + + if matches!(self.status(), InvokerStatus::Aborting) { + trace!("Aborting response stream processing for request id {request_id}"); + break; + } + } + trace!("Finished processing service call results for request id {request_id}"); + }) + } + + /// Processes a single response stream item. + /// + /// Calls [`process_feedback`] if necessary. Extracts any response for + /// further processing by the caller. 
+ /// + /// On [`ServiceError`] calls the trait impl'd [`set_status`] function + /// with `InvokerStatus::Aborting` and returns a generated error response + /// instead of the response from the service. + fn process_response_stream_item( + &mut self, + stream_item: ServiceResult, + req_msg: &Message, + ) -> Option>> { + match stream_item { + Ok(call_result) => { + let (response, feedback) = call_result.into_inner(); + if let Some(feedback) = feedback { + self.process_feedback(feedback); + } + response + } + + Err(err) => { + self.set_status(InvokerStatus::Aborting); + Some(mk_error_response(req_msg, err.rcode().into())) + } + } + } + + /// Acts on [`ServiceFeedback`] received from the [`Service`]. + /// + /// Calls the trait impl'd [`reconfigure`] on + /// [`ServiceFeedback::Reconfigure`]. + /// + /// Calls the trait impl'd [`set_status`] on + /// [`ServiceFeedback::BeginTransaction`] with + /// [`InvokerStatus::InTransaction`]. + /// + /// Calls the trait impl'd [`set_status`] on + /// [`ServiceFeedback::EndTransaction`] with [`InvokerStatus::Normal`]. + fn process_feedback(&mut self, feedback: ServiceFeedback) { + match feedback { + ServiceFeedback::Reconfigure { idle_timeout } => { + self.reconfigure(idle_timeout); + } + + ServiceFeedback::BeginTransaction => { + self.set_status(InvokerStatus::InTransaction); + } + + ServiceFeedback::EndTransaction => { + self.set_status(InvokerStatus::Normal); + } + } + } + + /// Returns the current status of the service invoker. + fn status(&self) -> InvokerStatus; + + /// Sets the status of the service invoker to the given status. + fn set_status(&mut self, status: InvokerStatus); + + /// Reconfigures the network server with new settings. + fn reconfigure(&self, idle_timeout: Option); + + /// Enqueues a response for writing back to the client. 
+ fn enqueue_response<'a>( + &'a self, + response: AdditionalBuilder>, + meta: &'a EnqueueMeta, + ) -> Pin + Send + 'a>>; +} diff --git a/src/net/server/mod.rs b/src/net/server/mod.rs index 0cf38e232..044ef0423 100644 --- a/src/net/server/mod.rs +++ b/src/net/server/mod.rs @@ -278,6 +278,7 @@ pub mod batcher; pub mod buf; pub mod dgram; pub mod error; +pub mod invoker; pub mod message; pub mod metrics; pub mod middleware; diff --git a/src/stelline/client.rs b/src/stelline/client.rs index 63e194b93..6e52c8528 100644 --- a/src/stelline/client.rs +++ b/src/stelline/client.rs @@ -1,4 +1,5 @@ #![allow(clippy::type_complexity)] +use core::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use core::ops::Deref; use std::boxed::Box; @@ -14,6 +15,7 @@ use std::vec::Vec; use bytes::Bytes; #[cfg(all(feature = "std", test))] use mock_instant::thread_local::MockClock; +use tokio::time::Instant; use tracing::{debug, info_span, trace}; use tracing_subscriber::EnvFilter; @@ -25,6 +27,9 @@ use crate::net::client::request::{ GetResponseMulti, RequestMessage, RequestMessageMulti, SendRequest, SendRequestMulti, }; +use crate::net::server::message::{ + Request, TransportSpecificContext, UdpTransportContext, +}; use crate::stelline::matches::match_multi_msg; use crate::zonefile::inplace::Entry::Record; @@ -127,6 +132,10 @@ pub async fn do_client_simple>>>( request: R, ) -> Result<(), StellineErrorCause> { let mut resp: Option> = None; + let mock_client_addr = + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); + let mock_transport_ctx = + TransportSpecificContext::Udp(UdpTransportContext::new(None)); // Assume steps are in order. Maybe we need to define that. 
for step in &stelline.scenario.steps { @@ -156,7 +165,18 @@ pub async fn do_client_simple>>>( .entry .as_ref() .ok_or(StellineErrorCause::MissingStepEntry)?; - if !match_msg(entry, &answer, true) { + let client_addr = entry + .client_addr + .map(|ip| SocketAddr::new(ip, 0)) + .unwrap_or(mock_client_addr); + let req = Request::new( + client_addr, + Instant::now(), + answer, + mock_transport_ctx.clone(), + (), + ); + if !match_msg(entry, &req, true) { return Err(StellineErrorCause::MismatchedAnswer); } } @@ -459,6 +479,11 @@ pub async fn do_client<'a, T: ClientFactory>( MockClock::set_system_time(Duration::ZERO); } + let mock_client_addr = + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); + let mock_transport_ctx = + TransportSpecificContext::Udp(UdpTransportContext::new(None)); + // Assume steps are in order. Maybe we need to define that. for step in &stelline.scenario.steps { let span = @@ -501,6 +526,11 @@ pub async fn do_client<'a, T: ClientFactory>( return Err(StellineErrorCause::MissingResponse); }; + let client_addr = entry + .client_addr + .map(|ip| SocketAddr::new(ip, 0)) + .unwrap_or(mock_client_addr); + if entry .matches .as_ref() @@ -563,12 +593,19 @@ pub async fn do_client<'a, T: ClientFactory>( trace!("Received answer."); trace!(?resp); + let req = Request::new( + client_addr, + Instant::now(), + resp, + mock_transport_ctx.clone(), + (), + ); let mut out_entry = Some(vec![]); match_multi_msg( &entry, 0, - &resp, + &req, true, &mut out_entry, ); @@ -642,8 +679,15 @@ pub async fn do_client<'a, T: ClientFactory>( trace!("Received answer."); trace!(?resp); + let req = Request::new( + client_addr, + Instant::now(), + resp, + mock_transport_ctx.clone(), + (), + ); if !match_multi_msg( - entry, idx, &resp, true, &mut None, + entry, idx, &req, true, &mut None, ) { return Err( StellineErrorCause::MismatchedAnswer, diff --git a/src/stelline/connection.rs b/src/stelline/connection.rs index df5556c0d..3724b44e6 100644 --- a/src/stelline/connection.rs 
+++ b/src/stelline/connection.rs @@ -1,17 +1,23 @@ +use core::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + use std::pin::Pin; use std::sync::Arc; use std::task::Waker; use std::task::{Context, Poll}; use std::vec::Vec; -use super::client::CurrStepValue; -use super::parse_stelline::Stelline; -use super::server::do_server; - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::time::Instant; use crate::base::message_builder::AdditionalBuilder; use crate::base::Message; +use crate::net::server::message::{ + NonUdpTransportContext, Request, TransportSpecificContext, +}; + +use super::client::CurrStepValue; +use super::parse_stelline::Stelline; +use super::server::do_server; #[derive(Debug)] pub struct Connection { @@ -20,7 +26,6 @@ pub struct Connection { waker: Option, reply: Option>>, send_body: bool, - tmpbuf: Vec, } @@ -85,10 +90,23 @@ impl AsyncWrite for Connection { } let msg = Message::from_octets(self.tmpbuf[2..].to_vec()).unwrap(); self.tmpbuf = Vec::new(); - let opt_reply = do_server(&msg, &self.stelline, &self.step_value); - if opt_reply.is_some() { + + let mock_client_addr = + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); + let mock_transport_ctx = TransportSpecificContext::NonUdp( + NonUdpTransportContext::new(None), + ); + let req = Request::new( + mock_client_addr, + Instant::now(), + msg, + mock_transport_ctx.clone(), + (), + ); + + if let Some((opt_reply, _indices)) = do_server(&req, &self.stelline, &self.step_value) { // Do we need to support more than one reply? - self.reply = opt_reply; + self.reply = Some(opt_reply); let opt_waker = self.waker.take(); if let Some(waker) = opt_waker { waker.wake(); diff --git a/src/stelline/dgram.rs b/src/stelline/dgram.rs index f6f7ec68d..c4966c435 100644 --- a/src/stelline/dgram.rs +++ b/src/stelline/dgram.rs @@ -1,4 +1,6 @@ //! 
Provide server-side of datagram protocols +use core::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + use std::boxed::Box; use std::future::Future; use std::pin::Pin; @@ -8,12 +10,16 @@ use std::task::{Context, Poll, Waker}; use std::vec::Vec; use tokio::io::ReadBuf; +use tokio::time::Instant; use crate::base::message_builder::AdditionalBuilder; use crate::base::Message; use crate::net::client::protocol::{ AsyncConnect, AsyncDgramRecv, AsyncDgramSend, }; +use crate::net::server::message::{ + Request, TransportSpecificContext, UdpTransportContext, +}; use super::client::CurrStepValue; use super::parse_stelline::Stelline; @@ -97,12 +103,22 @@ impl AsyncDgramSend for DgramConnection { buf: &[u8], ) -> Poll> { let msg = Message::from_octets(buf).unwrap(); - let opt_reply = do_server(&msg, &self.stelline, &self.step_value); + let mock_client_addr = + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); + let mock_transport_ctx = + TransportSpecificContext::Udp(UdpTransportContext::new(None)); + let req = Request::new( + mock_client_addr, + Instant::now(), + msg, + mock_transport_ctx.clone(), + (), + ); let len = buf.len(); - if opt_reply.is_some() { + if let Some((opt_reply, _indices)) = do_server(&req, &self.stelline, &self.step_value) { // Do we need to support more than one reply? 
let mut reply = self.reply.lock().unwrap();
- *reply = opt_reply;
+ *reply = Some(opt_reply);
drop(reply);
let mut waker = self.waker.lock().unwrap();
let opt_waker = (*waker).take();
diff --git a/src/stelline/matches.rs b/src/stelline/matches.rs
index 3aa073f70..5198d5735 100644
--- a/src/stelline/matches.rs
+++ b/src/stelline/matches.rs
@@ -1,15 +1,18 @@
-use super::parse_stelline::{Entry, Matches, Question, Reply};
+use std::vec::Vec;
+
use crate::base::iana::{Opcode, OptRcode, Rtype};
use crate::base::opt::{Opt, OptRecord};
use crate::base::{Message, ParsedName, QuestionSection, RecordSection};
use crate::dep::octseq::Octets;
+use crate::net::server::message::Request;
use crate::rdata::ZoneRecordData;
use crate::zonefile::inplace::Entry as ZonefileEntry;
-use std::vec::Vec;
-pub fn match_msg<'a, Octs: AsRef<[u8]> + Clone + Octets + 'a>(
+use super::parse_stelline::{Entry, Matches, Question, Reply};
+
+pub fn match_msg<'a, Octs: AsRef<[u8]> + Clone + Octets + 'a + Send + Sync>(
entry: &Entry,
- msg: &'a Message,
+ msg: &'a Request,
verbose: bool,
) -> bool
where
@@ -18,16 +21,20 @@ where
match_multi_msg(entry, 0, msg, verbose, &mut None)
}
-pub fn match_multi_msg<'a, Octs: AsRef<[u8]> + Clone + Octets + 'a>(
+pub fn match_multi_msg<
+ 'a,
+ Octs: AsRef<[u8]> + Clone + Octets + 'a + Send + Sync,
+>(
entry: &Entry,
idx: usize,
- msg: &'a Message,
+ req: &'a Request,
verbose: bool,
out_answer: &mut Option>,
) -> bool
where
::Range<'a>: Clone,
{
+ let msg = req.message();
let sections = entry.sections.as_ref().unwrap();
let mut matches: Matches = match &entry.matches {
@@ -261,7 +268,11 @@ where
}
}
if matches.opcode {
- let expected_opcode = if reply.notify {
+ // Test against default matches as that is what is used on mock queries
+ // and we don't want to require opcode NOTIFY on a reply. This first if
+ // check probably shouldn't even be here but instead the tests using
+ // REPLY NOTIFY should actually be using OPCODE NOTIFY.
+ let expected_opcode = if reply.notify && matches != Matches::default() { Opcode::NOTIFY } else if let Some(opcode) = entry.opcode { opcode @@ -308,16 +319,20 @@ where _ => { /* Okay */ } } } - if matches.tcp { - // Note: Creation of a TCP client is handled by the client factory passed to do_client(). - // TODO: Verify that the client is actually a TCP client. - } if matches.ttl { // Nothing to do. TTLs are checked in the relevant sections. } - if matches.udp { - // Note: Creation of a UDP client is handled by the client factory passed to do_client(). - // TODO: Verify that the client is actually a UDP client. + if matches.tcp && req.transport_ctx().is_udp() { + if verbose { + println!("Wrong transport type, expected TCP, got UDP"); + } + return false; + } + if matches.udp && req.transport_ctx().is_non_udp() { + if verbose { + println!("Wrong transport type, expected UDP, got non-UDP"); + } + return false; } // All checks passed! diff --git a/src/stelline/parse_stelline.rs b/src/stelline/parse_stelline.rs index 4df5f34b2..4a409c449 100644 --- a/src/stelline/parse_stelline.rs +++ b/src/stelline/parse_stelline.rs @@ -369,15 +369,74 @@ fn parse_entry>>( continue; } if token == MATCH { - entry.matches = Some(parse_match(tokens)); + let new_matches = parse_match(tokens); + match &mut entry.matches { + Some(matches) => { + matches.additional |= new_matches.additional; + matches.all |= new_matches.all; + matches.answer |= new_matches.answer; + matches.authority |= new_matches.authority; + matches.ad |= new_matches.ad; + matches.cd |= new_matches.cd; + matches.fl_do |= new_matches.fl_do; + matches.rd |= new_matches.rd; + matches.flags |= new_matches.flags; + matches.opcode |= new_matches.opcode; + matches.qname |= new_matches.qname; + matches.qtype |= new_matches.qtype; + matches.question |= new_matches.question; + matches.rcode |= new_matches.rcode; + matches.subdomain |= new_matches.subdomain; + matches.tcp |= new_matches.tcp; + matches.ttl |= new_matches.ttl; + 
matches.udp |= new_matches.udp; + matches.server_cookie |= new_matches.server_cookie; + matches.edns_data |= new_matches.edns_data; + matches.mock_client |= new_matches.mock_client; + matches.conn_closed |= new_matches.conn_closed; + matches.extra_packets |= new_matches.extra_packets; + matches.any_answer |= new_matches.any_answer; + } + None => entry.matches = Some(new_matches), + } continue; } if token == ADJUST { - entry.adjust = Some(parse_adjust(tokens)); + let new_adjust = parse_adjust(tokens); + match &mut entry.adjust { + Some(adjust) => { + adjust.copy_id |= new_adjust.copy_id; + adjust.copy_query |= new_adjust.copy_query; + } + None => entry.adjust = Some(new_adjust), + } continue; } if token == REPLY { - entry.reply = Some(parse_reply(tokens)); + let new_reply = parse_reply(tokens); + match &mut entry.reply { + Some(reply) => { + reply.aa |= new_reply.aa; + reply.ad |= new_reply.ad; + reply.cd |= new_reply.cd; + reply.fl_do |= new_reply.fl_do; + reply.qr |= new_reply.qr; + reply.ra |= new_reply.ra; + reply.rd |= new_reply.rd; + reply.tc |= new_reply.tc; + if new_reply.rcode.is_some() { + reply.rcode = new_reply.rcode; + } + reply.noerror |= new_reply.noerror; + reply.notimp |= new_reply.notimp; + reply.nxdomain |= new_reply.nxdomain; + reply.refused |= new_reply.refused; + reply.servfail |= new_reply.servfail; + reply.yxdomain |= new_reply.yxdomain; + reply.notify |= new_reply.notify; + } + None => entry.reply = Some(new_reply), + } continue; } if token == SECTION { @@ -550,7 +609,7 @@ fn parse_section>>( } } -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct Matches { pub additional: bool, pub all: bool, @@ -717,6 +776,11 @@ fn parse_reply(mut tokens: LineTokens<'_>) -> Reply { reply.rcode = Some(rcode); } else if token == "NOTIFY" { reply.notify = true; + } else if token == "QUERY" { + // We don't currently handle this anywhere yet as it's not clear + // what to do when this is specified. 
+ } else if token == "NOTIMPL" { + reply.rcode = Some(OptRcode::NOTIMP); } else { println!("should handle reply {token:?}"); todo!(); diff --git a/src/stelline/server.rs b/src/stelline/server.rs index b9d06da9d..84c1f0975 100644 --- a/src/stelline/server.rs +++ b/src/stelline/server.rs @@ -8,6 +8,7 @@ use crate::base::message_builder::AdditionalBuilder; use crate::base::wire::Composer; use crate::base::{Message, MessageBuilder}; use crate::dep::octseq::Octets; +use crate::net::server::message::Request; use crate::zonefile::inplace::Entry as ZonefileEntry; use super::client::CurrStepValue; @@ -15,20 +16,31 @@ use super::matches::match_msg; use super::parse_stelline; use super::parse_stelline::{Adjust, Reply, Stelline}; +/// Gets a matching Stelline range entry. +/// +/// Entries inside a RANGE_BEGIN/RANGE_END block within a Stelline file define +/// queries to match and if matched the response to serve to that query. +/// +/// The _last_ matching entry is returned, as apparently that "works better if +/// the (Stelline) RPL is written with a recursive resolver in mind", along +/// with the zero based index of the range the entry was found in, and the +/// zero based index of the entry within that range. pub fn do_server<'a, Oct, Target>( - msg: &'a Message, + req: &'a Request, stelline: &Stelline, step_value: &CurrStepValue, -) -> Option> +) -> Option<(AdditionalBuilder, (usize, usize))> where ::Range<'a>: Clone, - Oct: Clone + Octets + 'a, + Oct: Clone + Octets + 'a + Send + Sync, Target: Composer + Default + OctetsBuilder + Truncate, ::AppendError: Debug, { let ranges = &stelline.scenario.ranges; let step = step_value.get(); let mut opt_entry = None; + let mut last_found_indices: Option<(usize, usize)> = None; + let msg = req.message(); // Take the last entry. That works better if the RPL is written with // a recursive resolver in mind. 
@@ -37,7 +49,7 @@ where msg.header().opcode(), msg.first_question().unwrap().qtype() ); - for range in ranges { + for (range_idx, range) in ranges.iter().enumerate() { trace!( "Checking against range {} <= {}", range.start_value, @@ -46,10 +58,11 @@ where if step < range.start_value || step > range.end_value { continue; } - for entry in &range.entry { - if match_msg(entry, msg, false) { + for (entry_idx, entry) in range.entry.iter().enumerate() { + if match_msg(entry, req, true) { trace!("Match found"); opt_entry = Some(entry); + last_found_indices = Some((range_idx, entry_idx)) } } } @@ -57,12 +70,12 @@ where match opt_entry { Some(entry) => { let reply = do_adjust(entry, msg); - Some(reply) + Some((reply, last_found_indices.unwrap())) } None => { trace!("No matching reply found"); println!("do_server: no reply at step value {step}"); - todo!(); + None } } }