Skip to content

Commit

Permalink
Use tokio.sync::CancellationToken to quit service
Browse files Browse the repository at this point in the history
  • Loading branch information
sword-jin committed Jun 17, 2024
1 parent be14d12 commit b96d4e4
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 34 deletions.
13 changes: 7 additions & 6 deletions src/config_watcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ use crate::{
Config,
};
use anyhow::{Context, Result};
use tokio_util::sync::CancellationToken;
use std::{
collections::HashMap,
env,
path::{Path, PathBuf},
};
use tokio::sync::{broadcast, mpsc};
use tokio::sync::mpsc;
use tracing::{error, info, instrument};

#[cfg(feature = "notify")]
Expand Down Expand Up @@ -98,7 +99,7 @@ pub struct ConfigWatcherHandle {
}

impl ConfigWatcherHandle {
pub async fn new(path: &Path, shutdown_rx: broadcast::Receiver<bool>) -> Result<Self> {
pub async fn new(path: &Path, cancel: CancellationToken) -> Result<Self> {
let (event_tx, event_rx) = mpsc::unbounded_channel();
let origin_cfg = Config::from_file(path).await?;

Expand All @@ -109,7 +110,7 @@ impl ConfigWatcherHandle {

tokio::spawn(config_watcher(
path.to_owned(),
shutdown_rx,
cancel,
event_tx,
origin_cfg,
));
Expand All @@ -132,10 +133,10 @@ async fn config_watcher(
}

#[cfg(feature = "notify")]
#[instrument(skip(shutdown_rx, event_tx, old))]
#[instrument(skip(cancel, event_tx, old))]
async fn config_watcher(
path: PathBuf,
mut shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
event_tx: mpsc::UnboundedSender<ConfigChange>,
mut old: Config,
) -> Result<()> {
Expand Down Expand Up @@ -190,7 +191,7 @@ async fn config_watcher(
None => break
}
},
_ = shutdown_rx.recv() => break
_ = cancel.cancelled() => break
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub use constants::UDP_BUFFER_SIZE;

use anyhow::Result;
use tokio::sync::{broadcast, mpsc};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info};

#[cfg(feature = "client")]
Expand Down Expand Up @@ -59,7 +60,7 @@ fn genkey(curve: Option<KeypairType>) -> Result<()> {
crate::helper::feature_not_compile("nosie")
}

pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
pub async fn run(args: Cli, cancel: CancellationToken) -> Result<()> {
if args.genkey.is_some() {
return genkey(args.genkey.unwrap());
}
Expand All @@ -69,7 +70,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()

// Spawn a config watcher. The watcher will send a initial signal to start the instance with a config
let config_path = args.config_path.as_ref().unwrap();
let mut cfg_watcher = ConfigWatcherHandle::new(config_path, shutdown_rx).await?;
let mut cfg_watcher = ConfigWatcherHandle::new(config_path, cancel).await?;

// shutdown_tx owns the instance
let (shutdown_tx, _) = broadcast::channel(1);
Expand Down
14 changes: 6 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
use anyhow::Result;
use clap::Parser;
use rathole::{run, Cli};
use tokio::{signal, sync::broadcast};
use tokio::signal;
use tracing_subscriber::EnvFilter;
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() -> Result<()> {
let args = Cli::parse();

let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
let cancel_sender = CancellationToken::new();
let cancel_reader = cancel_sender.clone();
tokio::spawn(async move {
if let Err(e) = signal::ctrl_c().await {
// Something really weird happened. So just panic
panic!("Failed to listen for the ctrl-c signal: {:?}", e);
}

if let Err(e) = shutdown_tx.send(true) {
// shutdown signal must be catched and handle properly
// `rx` must not be dropped
panic!("Failed to send shutdown signal: {:?}", e);
}
cancel_sender.cancel(); // synchronously
});

#[cfg(feature = "console")]
Expand All @@ -41,5 +39,5 @@ async fn main() -> Result<()> {
.init();
}

run(args, shutdown_rx).await
run(args, cancel_reader).await
}
10 changes: 5 additions & 5 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,36 @@ use anyhow::Result;
use tokio::{
io::{self, AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, ToSocketAddrs},
sync::broadcast,
};
use tokio_util::sync::CancellationToken;

pub const PING: &str = "ping";
pub const PONG: &str = "pong";

pub async fn run_rathole_server(
config_path: &str,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
) -> Result<()> {
let cli = rathole::Cli {
config_path: Some(PathBuf::from(config_path)),
server: true,
client: false,
..Default::default()
};
rathole::run(cli, shutdown_rx).await
rathole::run(cli, cancel).await
}

pub async fn run_rathole_client(
config_path: &str,
shutdown_rx: broadcast::Receiver<bool>,
cancel: CancellationToken,
) -> Result<()> {
let cli = rathole::Cli {
config_path: Some(PathBuf::from(config_path)),
server: false,
client: true,
..Default::default()
};
rathole::run(cli, shutdown_rx).await
rathole::run(cli, cancel).await
}

pub mod tcp {
Expand Down
29 changes: 16 additions & 13 deletions tests/integration_test.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use anyhow::{Ok, Result};
use common::{run_rathole_client, PING, PONG};
use rand::Rng;
use tokio_util::sync::CancellationToken;
use std::time::Duration;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpStream, UdpSocket},
sync::broadcast,
time,
};
use tracing::{debug, info, instrument};
Expand Down Expand Up @@ -117,13 +117,13 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {
return Ok(());
}

let (client_shutdown_tx, client_shutdown_rx) = broadcast::channel(1);
let (server_shutdown_tx, server_shutdown_rx) = broadcast::channel(1);
let (cancel_client_tx, cancel_server_tx) = (CancellationToken::new(), CancellationToken::new());

// Start the client
info!("start the client");
let cancel_client_rx = cancel_client_tx.clone();
let client = tokio::spawn(async move {
run_rathole_client(config_path, client_shutdown_rx)
run_rathole_client(config_path, cancel_client_rx)
.await
.unwrap();
});
Expand All @@ -133,8 +133,9 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {

// Start the server
info!("start the server");
let cancel_server_rx = cancel_server_tx.clone();
let server = tokio::spawn(async move {
run_rathole_server(config_path, server_shutdown_rx)
run_rathole_server(config_path, cancel_server_rx)
.await
.unwrap();
});
Expand All @@ -149,13 +150,14 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {

// Simulate the client crash and restart
info!("shutdown the client");
client_shutdown_tx.send(true)?;
cancel_client_tx.cancel();
let _ = tokio::join!(client);

info!("restart the client");
let client_shutdown_rx = client_shutdown_tx.subscribe();
let restart_client_cancel_tx = CancellationToken::new();
let restart_client_cancel_rx = restart_client_cancel_tx.clone();
let client = tokio::spawn(async move {
run_rathole_client(config_path, client_shutdown_rx)
run_rathole_client(config_path, restart_client_cancel_rx)
.await
.unwrap();
});
Expand All @@ -170,13 +172,14 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {

// Simulate the server crash and restart
info!("shutdown the server");
server_shutdown_tx.send(true)?;
cancel_server_tx.cancel();
let _ = tokio::join!(server);

info!("restart the server");
let server_shutdown_rx = server_shutdown_tx.subscribe();
let restart_server_cancel_tx = CancellationToken::new();
let restart_server_cancel_rx = restart_server_cancel_tx.clone();
let server = tokio::spawn(async move {
run_rathole_server(config_path, server_shutdown_rx)
run_rathole_server(config_path, restart_server_cancel_rx)
.await
.unwrap();
});
Expand Down Expand Up @@ -205,8 +208,8 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {

// Shutdown
info!("shutdown the server and the client");
server_shutdown_tx.send(true)?;
client_shutdown_tx.send(true)?;
restart_client_cancel_tx.cancel();
restart_server_cancel_tx.cancel();

let _ = tokio::join!(server, client);

Expand Down

0 comments on commit b96d4e4

Please sign in to comment.