diff --git a/README.md b/README.md index 1eb4a79..88a5fda 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,4 @@ -
-████████▄  ███▄▄▄▄      ▄████████      ████████▄     ▄████████  ▄██████▄     ▄███████▄    ▄███████▄    ▄████████    ▄████████
-███   ▀███ ███▀▀▀██▄   ███    ███      ███   ▀███   ███    ███ ███    ███   ███    ███   ███    ███   ███    ███   ███    ███
-███    ███ ███   ███   ███    █▀       ███    ███   ███    ███ ███    ███   ███    ███   ███    ███   ███    █▀    ███    ███
-███    ███ ███   ███   ███             ███    ███  ▄███▄▄▄▄██▀ ███    ███   ███    ███   ███    ███  ▄███▄▄▄      ▄███▄▄▄▄██▀
-███    ███ ███   ███ ▀███████████      ███    ███ ▀▀███▀▀▀▀▀   ███    ███ ▀█████████▀  ▀█████████▀  ▀▀███▀▀▀     ▀▀███▀▀▀▀▀   
-███    ███ ███   ███          ███      ███    ███ ▀███████████ ███    ███   ███          ███          ███    █▄  ▀███████████
-███   ▄███ ███   ███    ▄█    ███      ███   ▄███   ███    ███ ███    ███   ███          ███          ███    ███   ███    ███
-████████▀   ▀█   █▀   ▄████████▀       ████████▀    ███    ███  ▀██████▀   ▄████▀       ▄████▀        ██████████   ███    ███
-███    ███  
-
+DNSDropper ## What is it? DNSDropper is a tool for anyone looking for a light-weight dns proxy with filtering capabilities. Like blocking ads! It works by being a proxy in-between you and your normal DNS service, filtering any DNS requests for domains in your black list. @@ -24,7 +14,6 @@ DNSDropper uses in a single configuration file that is divided up by major featu You can also find examples of different configurations under the ```test/``` folder. - ## How to use (standard configuration) 1. Configure your ```server.yaml``` to fit your needs, and run the service. 1. To specify the config directory, use the ```--config``` or ```-c``` argument. e.g. ```dns_dropper --config config/myconfig.yaml``` diff --git a/config/internal.yaml b/config/internal.yaml index 5c68a50..43d5650 100644 --- a/config/internal.yaml +++ b/config/internal.yaml @@ -1,4 +1,4 @@ # Config constants internal systems default_server_config_dir: "config/server.yaml" max_udp_packet_size: 4096 -worker_thread_name: "WORKER" +worker_thread_name: "WT" diff --git a/docs/dns_dropper_banner01.png b/docs/dns_dropper_banner01.png new file mode 100644 index 0000000..85d8ac8 Binary files /dev/null and b/docs/dns_dropper_banner01.png differ diff --git a/src/filter.rs b/src/filter.rs index a1b6c6c..ec062bd 100644 --- a/src/filter.rs +++ b/src/filter.rs @@ -12,76 +12,87 @@ pub(crate) struct Filter { pub domain: String, } -pub(crate) fn should_filter(domain: &String, filter_list: &HashSet) -> bool { - for entry in filter_list { - if &entry.domain == domain { +impl Filter { + fn is_domain_matching(&self, in_domain: &String) -> bool { + return self.domain.eq(in_domain); + } +} + +pub(crate) fn should_filter( + domain: &String, + filter_list: &HashSet +) -> bool { + for filter in filter_list { + if filter.is_domain_matching(domain) { return true; } } return false; } -pub(crate) async fn load_block_list(block_list: &[Cow<'_, str>]) -> HashSet { +pub(crate) async fn load_filtered_domains( + block_list: &[Cow<'_, str>] +) -> HashSet { let mut complete_block_list: HashSet = HashSet::new(); for source in block_list { - match Url::parse(source) { - Ok(url) => { - if url.scheme().eq("file") { - match fs::read_to_string(source.clone().into_owned()) { - Ok(content) => { - parse_block_list_content(&mut complete_block_list, content) - .unwrap_or_else(|error| error!("Error occurred while trying to parse content from provided url resource: {} {}", source, error)); - } - Err(err) => { - error!("Error occurred while reading file '{}': {}", source, err); - } - }; - } else if url.scheme().eq("http") || url.scheme().eq("https") { - match reqwest::get(source.clone().into_owned()).await { - Ok(res) => { - trace!("Got response from block-list source: {}", source); - if res.status() == StatusCode::OK { - if let Ok(body) = res.text().await { - parse_block_list_content(&mut complete_block_list, body) - .unwrap_or_else(|error| error!("Error occurred while trying to parse content from provided url resource: {} {}", source, error)); - } - } - } - Err(err) => { - error!("Error occurred while requesting resource from '{}': {}", source, err); - } - }; + + if let Ok(url) = Url::parse(source) { + + if url.scheme().eq("file") { + + if let Ok(content) = fs::read_to_string(source.clone().into_owned()) { + parse_block_list_content(&mut complete_block_list, content).unwrap(); } - } - Err(_) => { - trace!("Provided string '{}' is not a URL, trying as an external file.", source); - match fs::File::open(source.clone().into_owned()) { - Ok(file) => { - let mut buf_reader = BufReader::new(file); - let mut body = String::new(); - match buf_reader.read_to_string(&mut body) { - Ok(_) => { - parse_block_list_content(&mut complete_block_list, body) - .unwrap_or_else(|error| error!("Error occurred while trying to parse content from provided local resource: {} {}", source, error)); - } - Err(err) => { - error!("Error occurred while reading file '{}': {}", source, err); - } + } else if url.scheme().eq("http") || url.scheme().eq("https") { + + if let Ok(res) = reqwest::get(source.clone().into_owned()).await { + trace!("Got response from block-list source: {}", source); + + if res.status() == StatusCode::OK { + + if let Ok(body) = res.text().await { + parse_block_list_content(&mut complete_block_list, body).unwrap(); } + + } else { + error!("Error! Response was: {}", res.status()); } - Err(err) => { - error!("Error occurred while reading file '{}': {}", source, err); - } + + } else { + error!("Error occurred while requesting resource from '{}'", source); + }; + } + + } else { + trace!("Provided string '{}' is not a URL, trying as an external file.", source); + + if let Ok(file) = fs::File::open(source.clone().into_owned()) { + let mut buf_reader = BufReader::new(file); + let mut body = String::new(); + + if let Ok(_) = buf_reader.read_to_string(&mut body) { + parse_block_list_content(&mut complete_block_list, body).unwrap(); + + } else if let Err(err) = buf_reader.read_to_string(&mut body) { + error!("Error occurred while reading file '{}': {}", source, err); } + + } else if let Err(err) = fs::File::open(source.clone().into_owned()) { + error!("Error occurred while reading file '{}': {}", source, err); } }; } + return complete_block_list; } -fn parse_block_list_content(complete_block_list: &mut HashSet, content: String) -> Result<()> { +fn parse_block_list_content( + complete_block_list: &mut HashSet, + content: String +) -> Result<()> { + let mut filter: Filter; for line in content.lines() { diff --git a/src/internal.rs b/src/internal.rs index e278c16..8706dd3 100644 --- a/src/internal.rs +++ b/src/internal.rs @@ -14,5 +14,5 @@ pub struct InternalConfig { pub const INTERNAL_CONFIG: InternalConfig = InternalConfig { default_server_config_dir: Cow::Borrowed("config/server.yaml"), max_udp_packet_size: 4096, - worker_thread_name: Cow::Borrowed("WORKER"), + worker_thread_name: Cow::Borrowed("WT"), }; diff --git a/src/logging.rs b/src/logging.rs index f52d00f..342a833 100644 --- a/src/logging.rs +++ b/src/logging.rs @@ -4,30 +4,28 @@ use log::Level; use std::io::Write; use crate::logging::HighlightStyle::{DebugHighlight, ErrorHighlight, InfoHighlight, TraceHighlight, WarnHighLight}; -const DARK_GREY_HIGHLIGHT: Style = Style::new() +pub const DARK_GREY_HIGHLIGHT: Style = Style::new() .fg_color(Some(Color::Ansi256(Ansi256Color(8)))); -const RED_HIGHLIGHT: Style = Style::new() +pub const RED_HIGHLIGHT: Style = Style::new() .fg_color(Some(Color::Ansi256(Ansi256Color(9)))); -const GREEN_HIGHLIGHT: Style = Style::new() +pub const GREEN_HIGHLIGHT: Style = Style::new() .fg_color(Some(Color::Ansi256(Ansi256Color(10)))); -const YELLOW_HIGHLIGHT: Style = Style::new() +pub const YELLOW_HIGHLIGHT: Style = Style::new() .fg_color(Some(Color::Ansi256(Ansi256Color(11)))); -const BLUE_HIGHLIGHT: Style = Style::new() +pub const BLUE_HIGHLIGHT: Style = Style::new() .fg_color(Some(Color::Ansi256(Ansi256Color(12)))); -const PURPLE_HIGHLIGHT: Style = Style::new() +pub const PURPLE_HIGHLIGHT: Style = Style::new() .fg_color(Some(Color::Ansi256(Ansi256Color(13)))); -const AQUA_HIGHLIGHT: Style = Style::new() +pub const AQUA_HIGHLIGHT: Style = Style::new() .fg_color(Some(Color::Ansi256(Ansi256Color(14)))); - -const DEFAULT_STYLE: Style = BLUE_HIGHLIGHT; -const TRACE_STYLE: Style = PURPLE_HIGHLIGHT.bold(); -const INFO_STYLE: Style = BLUE_HIGHLIGHT.bold(); -const ERROR_STYLE: Style = RED_HIGHLIGHT.bold(); -const DEBUG_STYLE: Style = GREEN_HIGHLIGHT.bold(); -const WARN_STYLE: Style = YELLOW_HIGHLIGHT.bold(); -const TIMESTAMP_STYLE: Style = DARK_GREY_HIGHLIGHT.underline(); -const THREAD_NAME_STYLE: Style = AQUA_HIGHLIGHT.bold(); -const MODULE_INFO_STYLE: Style = YELLOW_HIGHLIGHT.italic(); +pub const DEFAULT_STYLE: Style = BLUE_HIGHLIGHT; +pub const TRACE_STYLE: Style = PURPLE_HIGHLIGHT.bold(); +pub const INFO_STYLE: Style = BLUE_HIGHLIGHT.bold(); +pub const ERROR_STYLE: Style = RED_HIGHLIGHT.bold(); +pub const DEBUG_STYLE: Style = GREEN_HIGHLIGHT.bold(); +pub const WARN_STYLE: Style = YELLOW_HIGHLIGHT.bold(); +pub const TIMESTAMP_STYLE: Style = DARK_GREY_HIGHLIGHT.underline(); +pub const THREAD_NAME_STYLE: Style = AQUA_HIGHLIGHT.bold(); pub enum HighlightStyle { TraceHighlight, @@ -82,21 +80,18 @@ pub fn setup(level: &str) { } }; - - let ts = buf.timestamp_millis(); - let mod_path = record.module_path(); - let mod_line = record.line(); let lvl = record.level(); let args = record.args(); - writeln!(buf, "[{TIMESTAMP_STYLE}{}{TIMESTAMP_STYLE:#}][{THREAD_NAME_STYLE}{}{THREAD_NAME_STYLE:#}][{level_colour}{}{level_colour:#}][{MODULE_INFO_STYLE}{}.rs::{}{MODULE_INFO_STYLE:#}] {DEFAULT_STYLE}{}{DEFAULT_STYLE:#}", - ts, - thread::current().name().unwrap_or_default().to_ascii_uppercase(), - lvl, - mod_path.unwrap_or_default(), - mod_line.unwrap_or_default(), - args) + writeln!( + buf, + "[{TIMESTAMP_STYLE}{}{TIMESTAMP_STYLE:#}][{level_colour}{}{level_colour:#}][{THREAD_NAME_STYLE}{}{THREAD_NAME_STYLE:#}] {DEFAULT_STYLE}{}{DEFAULT_STYLE:#}", + ts, + lvl, + thread::current().name().unwrap_or_default().to_ascii_uppercase(), + args + ) }).init(); } diff --git a/src/server.rs b/src/server.rs index 1bbd7c6..69cbcd2 100644 --- a/src/server.rs +++ b/src/server.rs @@ -4,6 +4,8 @@ mod filter; mod logging; mod internal; +use crate::logging::WARN_STYLE; +use crate::logging::ERROR_STYLE; use std::borrow::Cow; use std::collections::{HashSet, VecDeque}; use std::env; @@ -11,7 +13,6 @@ use std::io::{Error, ErrorKind}; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; -use env_logger::fmt::style::Style; use tokio::io::Result; use tokio::net::UdpSocket; use tokio::time::timeout; @@ -19,13 +20,13 @@ use tokio::time::timeout; use crate::filter::Filter; use crate::dns::{BytePacketBuffer, DnsPacket, ResultCode}; use crate::internal::INTERNAL_CONFIG; -use crate::logging::{GetStyle, HighlightStyle}; -use crate::logging::HighlightStyle::ErrorHighlight; use crate::server_config::ServerConfig; fn main() -> Result<()> { let mut args: VecDeque = env::args().collect(); let mut config_dir: String = INTERNAL_CONFIG.default_server_config_dir.to_string(); + + // TODO - there is probably a crate for this already. use that. if args.len() > 1 { args.pop_front(); while !args.is_empty() { @@ -37,14 +38,12 @@ fn main() -> Result<()> { } } - - match ServerConfig::load_from(std::path::Path::new(&config_dir)) { Ok(server_config) => { logging::setup(server_config.logging.level.as_ref()); logging::print_title(); - start_server(&&server_config).expect("Failed to start server"); - Ok(()) + + start_server(server_config) } Err(_) => { @@ -54,24 +53,37 @@ fn main() -> Result<()> { } -fn start_server<'a>(config: &'a Cow<'a, ServerConfig>) -> Result<()> { +fn start_server<'a>( + config: Cow +) -> Result<()> { + let rt = tokio::runtime::Builder::new_multi_thread() .worker_threads(config.server.worker_thread_count as usize) - .thread_name_fn(||{ + .thread_name_fn(|| { static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); - format!("{}-{}",INTERNAL_CONFIG.worker_thread_name, id) + format!("{}-{}", INTERNAL_CONFIG.worker_thread_name, id) + }) + .on_thread_start(|| { + log::debug!("Starting worker thread."); + }) + .on_thread_stop(|| { + log::debug!("Stopping worker thread."); + }) + .on_thread_park(|| { + log::trace!("Parking worker thread."); + }) + .on_thread_unpark(|| { + log::trace!("Un-parking worker thread."); }) .enable_io() .enable_time() .build()?; rt.block_on(async { - let complete_block_list = filter::load_block_list(config.clone().udp_proxy.domain_block_lists.as_ref()).await; + let complete_block_list = filter::load_filtered_domains(config.clone().udp_proxy.domain_block_lists.as_ref()).await; let bind = config.udp_proxy.bind.to_string(); - log::info!("Creating server listening on bind: {}", bind); - let socket = UdpSocket::bind(bind).await?; let arc_socket = Arc::new(socket); let arc_config = Arc::new(config.clone().into_owned()); @@ -79,7 +91,7 @@ fn start_server<'a>(config: &'a Cow<'a, ServerConfig>) -> Result<()> { log::debug!("Block list contains '{}' different domain names.", complete_block_list.len()); let arc_block_list = Arc::new(complete_block_list); - + log::info!("Started DNS Proxy: {}", config.udp_proxy.bind); loop { match arc_socket.ready(tokio::io::Interest::READABLE).await { Ok(r) => { @@ -89,11 +101,13 @@ fn start_server<'a>(config: &'a Cow<'a, ServerConfig>) -> Result<()> { let arc_socket_clone = arc_socket.clone(); let arc_config_clone = arc_config.clone(); let arc_block_list_clone = arc_block_list.clone(); + tokio::spawn(async move { start_udp_dns_listener(arc_socket_clone, arc_config_clone, arc_block_list_clone).await; }); } } + Err(err) => { log::error!("Error trying to read from socket: {}", err); } @@ -102,71 +116,85 @@ fn start_server<'a>(config: &'a Cow<'a, ServerConfig>) -> Result<()> { }) } -async fn start_udp_dns_listener(socket: Arc, server_config: Arc, block_list: Arc>) { +async fn start_udp_dns_listener( + socket: Arc, + server_config: Arc, + block_list: Arc> +) { let mut req = BytePacketBuffer::new(); let (len, src) = match socket.try_recv_from(&mut req.buf) { - Ok(r) => r, + Ok(r) => { + log::trace!("DNS Listener received data of length: {} bytes", r.0); + r + }, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { return; } /* Lots of traffic produced when using loopback configuration on Windows. Why? I have no idea and I don't look forward to knowing. */ Err(ref e) if e.kind() == ErrorKind::ConnectionReset => { + log::trace!("Connection reset error occurred for local socket '{}'.", socket.local_addr().unwrap()); return; } Err(err) => { - log::error!("Failed to receive message from configured bind: {:?}", err.kind()); + log::error!("Failed to receive message from configured bind: {}", err); return; } }; if let Ok(mut packet) = DnsPacket::from_buffer(&mut req) { - match packet.questions.get(0) { - None => { - log::error!("Packet contains no questions"); - } - Some(query) => { - if filter::should_filter(&query.name, &block_list) { - let mut response_buffer = BytePacketBuffer::new(); - let style: Style = HighlightStyle::get_style(ErrorHighlight); - - let num = query.record_type.to_num(); - log::trace!("{style}BLOCK{style:#}: {}:{}", query.name, num); - - packet.header.result_code = ResultCode::NXDOMAIN; - match packet.write(&mut response_buffer) { - Ok(_) => { - match response_buffer.get_range(0, response_buffer.pos) { - Ok(data) => { - if let Err(err) = socket.send_to(&data, &src).await { - log::error!("Reply to '{}' failed {:?}", &src, err); - } - } - Err(err) => { - log::error!("Could not retrieve buffer range: {}", err); - } - } - } - Err(err) => { - log::error!("Error writing packet: {}", err); + log::trace!("Successfully constructed a DNS packet from received data."); + + if let None = packet.questions.get(0) { + log::error!("DNS Packet contains no questions"); + + } else if let Some(query) = packet.questions.get(0) { + log::trace!("Checking to see if the domain '{}' and/or record type '{}' should be filtered.", query.name, query.record_type.to_num()); + + if filter::should_filter(&query.name, &block_list) { + let mut response_buffer = BytePacketBuffer::new(); + log::info!("{ERROR_STYLE}!BLOCK!{ERROR_STYLE:#} -- {WARN_STYLE}{}:{}{WARN_STYLE:#}", query.name, query.record_type.to_num()); + packet.header.result_code = ResultCode::NXDOMAIN; + + if let Ok(_) = packet.write(&mut response_buffer) { + + if let Ok(data) = response_buffer.get_range(0, response_buffer.pos) { + + if let Err(err) = socket.send_to(&data, &src).await { + log::error!("Reply to '{}' failed {:?}", &src, err); } + + } else if let Err(err) = response_buffer.get_range(0, response_buffer.pos) { + log::error!("Could not retrieve buffer range: {}", err); } - } else { - for host in server_config.udp_proxy.dns_hosts.iter() { - match do_lookup(&req.buf[..len], host.to_string(), server_config.udp_proxy.timeout).await { - Ok(data) => { - if let Err(err) = socket.send_to(&data, &src).await { - log::error!("Replying to '{}' failed {:?}", &src, err); - continue; - } - return; - } - Err(err) => { - log::error!("Error processing request: {:?}", err); - } - }; - } + + } else if let Err(err) = packet.write(&mut response_buffer) { + log::error!("Error writing packet: {}", err); + } + + } else { + + for host in server_config.udp_proxy.dns_hosts.as_ref() { + + if let Ok(data) = do_lookup(&req.buf[..len], host.to_string(), server_config.udp_proxy.timeout).await { + + if let Err(err) = socket.send_to(&data, &src).await { + log::error!("Replying to '{}' failed {:?}", &src, err); + continue; + } + + log::debug!("Forwarded the answer for the domain '{}' from '{}'.", query.name, host); + return; + + } else if let Err(err) = do_lookup(&req.buf[..len], host.to_string(), server_config.udp_proxy.timeout).await { + log::error!("Error processing request: {:?}", err); + return; + + } else { + unreachable!() + }; } } } @@ -177,22 +205,27 @@ async fn do_lookup(buf: &[u8], remote_host: String, connection_timeout: i64) -> let duration = Duration::from_millis(connection_timeout as u64); let socket = UdpSocket::bind(("0.0.0.0", 0)).await?; - log::trace!("UDP socket bound to {:?} for: {:?}", socket.local_addr(), remote_host); + log::debug!("Outbound UDP socket bound to port '{}' for the host destination '{}'.", socket.local_addr().unwrap().port(), remote_host); let data: Result> = timeout(duration, async { socket.send_to(buf, remote_host.to_string()).await?; + let mut response = [0; INTERNAL_CONFIG.max_udp_packet_size as usize]; let length = socket.recv(&mut response).await?; + + log::debug!("Received response from '{}' for port '{}' with a length of: {} bytes", remote_host, socket.local_addr().unwrap().port(), length); + Ok(response[..length].to_vec()) }).await?; - match data { - Ok(data) => { - return Ok(data); - } - Err(err) => { - log::error!("Agent request to {:?} {:?}", remote_host, err); - } + if let Ok(data) = data { + return Ok(data) + + } else if let Err(err) = data { + log::error!("Agent request to {:?} {:?}", remote_host, err); + Err(err) + + } else { + unreachable!() } - Err(Error::new(ErrorKind::Other, "Proxy server failed to proxy request")) } diff --git a/test/config/server.yaml b/test/config/server.yaml index 83fbef1..8e8c416 100644 --- a/test/config/server.yaml +++ b/test/config/server.yaml @@ -1,9 +1,8 @@ server: - worker_thread_count: 4 + worker_thread_count: 3 udp_proxy: timeout: 2000 - packet_size: 512 bind: "127.0.0.1:53" dns_hosts: - "8.8.8.8:53"