Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use tokio_util::sync::CancellationToken to quit service #371

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::time::{self, Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};

#[cfg(feature = "noise")]
Expand All @@ -33,7 +34,7 @@ use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE
// The entrypoint of running a client
pub async fn run_client(
config: Config,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
let config = config.client.ok_or_else(|| {
Expand All @@ -45,13 +46,13 @@ pub async fn run_client(
match config.transport.transport_type {
TransportType::Tcp => {
let mut client = Client::<TcpTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
client.run(cancel, update_rx).await
}
TransportType::Tls => {
#[cfg(any(feature = "native-tls", feature = "rustls"))]
{
let mut client = Client::<TlsTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
client.run(cancel, update_rx).await
}
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
crate::helper::feature_neither_compile("native-tls", "rustls")
Expand All @@ -60,7 +61,7 @@ pub async fn run_client(
#[cfg(feature = "noise")]
{
let mut client = Client::<NoiseTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
client.run(cancel, update_rx).await
}
#[cfg(not(feature = "noise"))]
crate::helper::feature_not_compile("noise")
Expand All @@ -69,7 +70,7 @@ pub async fn run_client(
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
{
let mut client = Client::<WebsocketTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
client.run(cancel, update_rx).await
}
#[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
Expand Down Expand Up @@ -102,7 +103,7 @@ impl<T: 'static + Transport> Client<T> {
// The entrypoint of Client
async fn run(
&mut self,
mut shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
mut update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
for (name, config) in &self.config.services {
Expand All @@ -119,13 +120,7 @@ impl<T: 'static + Transport> Client<T> {
// Wait for the shutdown signal
loop {
tokio::select! {
val = shutdown_rx.recv() => {
match val {
Ok(_) => {}
Err(err) => {
error!("Unable to listen for shutdown signal: {}", err);
}
}
_ = cancel.cancelled() => {
break;
},
e = update_rx.recv() => {
Expand Down
8 changes: 4 additions & 4 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,9 @@ impl Config {
fn validate_server_config(server: &mut ServerConfig) -> Result<()> {
// Validate services
for (name, s) in &mut server.services {
s.name = name.clone();
s.name.clone_from(name);
if s.token.is_none() {
s.token = server.default_token.clone();
s.token.clone_from(&server.default_token);
if s.token.is_none() {
bail!("The token of service {} is not set", name);
}
Expand All @@ -272,9 +272,9 @@ impl Config {
fn validate_client_config(client: &mut ClientConfig) -> Result<()> {
// Validate services
for (name, s) in &mut client.services {
s.name = name.clone();
s.name.clone_from(name);
if s.token.is_none() {
s.token = client.default_token.clone();
s.token.clone_from(&client.default_token);
if s.token.is_none() {
bail!("The token of service {} is not set", name);
}
Expand Down
13 changes: 7 additions & 6 deletions src/config_watcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ use crate::{
Config,
};
use anyhow::{Context, Result};
use tokio_util::sync::CancellationToken;
use std::{
collections::HashMap,
env,
path::{Path, PathBuf},
};
use tokio::sync::{broadcast, mpsc};
use tokio::sync::mpsc;
use tracing::{error, info, instrument};

#[cfg(feature = "notify")]
Expand Down Expand Up @@ -98,7 +99,7 @@ pub struct ConfigWatcherHandle {
}

impl ConfigWatcherHandle {
pub async fn new(path: &Path, shutdown_rx: broadcast::Receiver<bool>) -> Result<Self> {
pub async fn new(path: &Path, cancel: CancellationToken) -> Result<Self> {
let (event_tx, event_rx) = mpsc::unbounded_channel();
let origin_cfg = Config::from_file(path).await?;

Expand All @@ -109,7 +110,7 @@ impl ConfigWatcherHandle {

tokio::spawn(config_watcher(
path.to_owned(),
shutdown_rx,
cancel,
event_tx,
origin_cfg,
));
Expand All @@ -132,10 +133,10 @@ async fn config_watcher(
}

#[cfg(feature = "notify")]
#[instrument(skip(shutdown_rx, event_tx, old))]
#[instrument(skip(cancel, event_tx, old))]
async fn config_watcher(
path: PathBuf,
mut shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
event_tx: mpsc::UnboundedSender<ConfigChange>,
mut old: Config,
) -> Result<()> {
Expand Down Expand Up @@ -190,7 +191,7 @@ async fn config_watcher(
None => break
}
},
_ = shutdown_rx.recv() => break
_ = cancel.cancelled() => break
}
}

Expand Down
22 changes: 11 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ pub use config::Config;
pub use constants::UDP_BUFFER_SIZE;

use anyhow::Result;
use tokio::sync::{broadcast, mpsc};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info};

#[cfg(feature = "client")]
Expand Down Expand Up @@ -59,7 +60,7 @@ fn genkey(curve: Option<KeypairType>) -> Result<()> {
crate::helper::feature_not_compile("nosie")
}

pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
pub async fn run(args: Cli, cancel: CancellationToken) -> Result<()> {
if args.genkey.is_some() {
return genkey(args.genkey.unwrap());
}
Expand All @@ -69,10 +70,9 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()

// Spawn a config watcher. The watcher will send a initial signal to start the instance with a config
let config_path = args.config_path.as_ref().unwrap();
let mut cfg_watcher = ConfigWatcherHandle::new(config_path, shutdown_rx).await?;
let mut cfg_watcher = ConfigWatcherHandle::new(config_path, cancel).await?;

// shutdown_tx owns the instance
let (shutdown_tx, _) = broadcast::channel(1);
let local_cancel_tx = CancellationToken::new();

// (The join handle of the last instance, The service update channel sender)
let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ConfigChange>)> = None;
Expand All @@ -82,7 +82,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
ConfigChange::General(config) => {
if let Some((i, _)) = last_instance {
info!("General configuration change detected. Restarting...");
shutdown_tx.send(true)?;
local_cancel_tx.cancel();
i.await??;
}

Expand All @@ -94,7 +94,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
tokio::spawn(run_instance(
*config,
args.clone(),
shutdown_tx.subscribe(),
local_cancel_tx.clone(),
service_update_rx,
)),
service_update_tx,
Expand All @@ -109,15 +109,15 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()
}
}

let _ = shutdown_tx.send(true);
local_cancel_tx.cancel();

Ok(())
}

async fn run_instance(
config: Config,
args: Cli,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
service_update: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
match determine_run_mode(&config, &args) {
Expand All @@ -126,13 +126,13 @@ async fn run_instance(
#[cfg(not(feature = "client"))]
crate::helper::feature_not_compile("client");
#[cfg(feature = "client")]
run_client(config, shutdown_rx, service_update).await
run_client(config, cancel, service_update).await
}
RunMode::Server => {
#[cfg(not(feature = "server"))]
crate::helper::feature_not_compile("server");
#[cfg(feature = "server")]
run_server(config, shutdown_rx, service_update).await
run_server(config, cancel, service_update).await
}
}
}
Expand Down
14 changes: 6 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
use anyhow::Result;
use clap::Parser;
use rathole::{run, Cli};
use tokio::{signal, sync::broadcast};
use tokio::signal;
use tracing_subscriber::EnvFilter;
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() -> Result<()> {
let args = Cli::parse();

let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
let cancel_tx = CancellationToken::new();
let cancel_rx = cancel_tx.clone();
tokio::spawn(async move {
if let Err(e) = signal::ctrl_c().await {
// Something really weird happened. So just panic
panic!("Failed to listen for the ctrl-c signal: {:?}", e);
}

if let Err(e) = shutdown_tx.send(true) {
// shutdown signal must be catched and handle properly
// `rx` must not be dropped
panic!("Failed to send shutdown signal: {:?}", e);
}
cancel_tx.cancel(); // synchronously
});

#[cfg(feature = "console")]
Expand All @@ -41,5 +39,5 @@ async fn main() -> Result<()> {
.init();
}

run(args, shutdown_rx).await
run(args, cancel_rx).await
}
15 changes: 8 additions & 7 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use backoff::backoff::Backoff;
use backoff::ExponentialBackoff;

use rand::RngCore;
use tokio_util::sync::CancellationToken;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
Expand Down Expand Up @@ -41,7 +42,7 @@ const HANDSHAKE_TIMEOUT: u64 = 5; // Timeout for transport handshake
// The entrypoint of running a server
pub async fn run_server(
config: Config,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
let config = match config.server {
Expand All @@ -54,13 +55,13 @@ pub async fn run_server(
match config.transport.transport_type {
TransportType::Tcp => {
let mut server = Server::<TcpTransport>::from(config).await?;
server.run(shutdown_rx, update_rx).await?;
server.run(cancel, update_rx).await?;
}
TransportType::Tls => {
#[cfg(any(feature = "native-tls", feature = "rustls"))]
{
let mut server = Server::<TlsTransport>::from(config).await?;
server.run(shutdown_rx, update_rx).await?;
server.run(cancel, update_rx).await?;
}
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
crate::helper::feature_neither_compile("native-tls", "rustls")
Expand All @@ -69,7 +70,7 @@ pub async fn run_server(
#[cfg(feature = "noise")]
{
let mut server = Server::<NoiseTransport>::from(config).await?;
server.run(shutdown_rx, update_rx).await?;
server.run(cancel, update_rx).await?;
}
#[cfg(not(feature = "noise"))]
crate::helper::feature_not_compile("noise")
Expand All @@ -78,7 +79,7 @@ pub async fn run_server(
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
{
let mut server = Server::<WebsocketTransport>::from(config).await?;
server.run(shutdown_rx, update_rx).await?;
server.run(cancel, update_rx).await?;
}
#[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
Expand Down Expand Up @@ -134,7 +135,7 @@ impl<T: 'static + Transport> Server<T> {
// The entry point of Server
pub async fn run(
&mut self,
mut shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
mut update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
// Listen at `server.bind_addr`
Expand Down Expand Up @@ -205,7 +206,7 @@ impl<T: 'static + Transport> Server<T> {
}
},
// Wait for the shutdown signal
_ = shutdown_rx.recv() => {
_ = cancel.cancelled() => {
info!("Shuting down gracefully...");
break;
},
Expand Down
10 changes: 5 additions & 5 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,36 @@ use anyhow::Result;
use tokio::{
io::{self, AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, ToSocketAddrs},
sync::broadcast,
};
use tokio_util::sync::CancellationToken;

pub const PING: &str = "ping";
pub const PONG: &str = "pong";

pub async fn run_rathole_server(
config_path: &str,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
) -> Result<()> {
let cli = rathole::Cli {
config_path: Some(PathBuf::from(config_path)),
server: true,
client: false,
..Default::default()
};
rathole::run(cli, shutdown_rx).await
rathole::run(cli, cancel).await
}

pub async fn run_rathole_client(
config_path: &str,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
) -> Result<()> {
let cli = rathole::Cli {
config_path: Some(PathBuf::from(config_path)),
server: false,
client: true,
..Default::default()
};
rathole::run(cli, shutdown_rx).await
rathole::run(cli, cancel).await
}

pub mod tcp {
Expand Down
Loading
Loading