From 8097b6916fc57bb4097a84f5c75feb56d8f306d8 Mon Sep 17 00:00:00 2001 From: Yujia Qiao Date: Sat, 25 Dec 2021 20:23:56 +0800 Subject: [PATCH] feat: hot-reload by restarting --- Cargo.lock | 120 ++++++++++++++++++++++++++++++++++++++ Cargo.toml | 1 + src/cli.rs | 2 +- src/config.rs | 20 +++---- src/config_watcher.rs | 119 +++++++++++++++++++++++++++++++++++++ src/lib.rs | 70 +++++++++++++++++++--- src/main.rs | 2 +- tests/common/mod.rs | 4 +- tests/integration_test.rs | 8 +-- 9 files changed, 319 insertions(+), 27 deletions(-) create mode 100644 src/config_watcher.rs diff --git a/Cargo.lock b/Cargo.lock index fdd4800..94bde51 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -272,6 +272,26 @@ dependencies = [ "libc", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db" +dependencies = [ + "cfg-if", + "lazy_static", +] + [[package]] name = "crypto-common" version = "0.1.1" @@ -342,6 +362,18 @@ dependencies = [ "libc", ] +[[package]] +name = "filetime" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "975ccf83d8d9d0d84682850a38c8169027be83368805971cc4f238c2b245bc98" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "winapi", +] + [[package]] name = "foreign-types" version = "0.3.2" @@ -357,6 +389,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "fsevent-sys" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c0e564d24da983c053beff1bb7178e237501206840a3e6bf4e267b9e8ae734a" +dependencies = [ + "libc", +] + [[package]] name = "futures-core" version = "0.3.19" @@ -476,6 +517,26 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "inotify" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff" +dependencies = [ + "bitflags", + "inotify-sys", + "libc", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "instant" version = "0.1.12" @@ -491,6 +552,26 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35" +[[package]] +name = "kqueue" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "058a107a784f8be94c7d35c1300f4facced2e93d2fbe5b1452b44e905ddca4a9" +dependencies = [ + "kqueue-sys", + "libc", +] + +[[package]] +name = "kqueue-sys" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8367585489f01bc55dd27404dcf56b95e6da061a256a666ab23be9ba96a2e587" +dependencies = [ + "bitflags", + "libc", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -576,6 +657,24 @@ dependencies = [ "tempfile", ] +[[package]] +name = "notify" +version = "5.0.0-pre.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245d358380e2352c2d020e8ee62baac09b3420f1f6c012a31326cfced4ad487d" +dependencies = [ + "bitflags", + "crossbeam-channel", + "filetime", + "fsevent-sys", + "inotify", + "kqueue", + "libc", + "mio", + "walkdir", + "winapi", +] + [[package]] name = "ntapi" version = "0.3.6" @@ -874,6 +973,7 @@ dependencies = [ "fdlimit", "hex", "lazy_static", + "notify", "rand", "serde", "sha2 0.10.0", @@ -943,6 +1043,15 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.19" @@ -1389,6 +1498,17 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe" +[[package]] +name = "walkdir" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" +dependencies = [ + "same-file", + "winapi", + "winapi-util", +] + [[package]] name = "wasi" version = "0.9.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index acf0343..c957b9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,3 +50,4 @@ tokio-native-tls = { version = "0.3.0", optional = true } async-trait = "0.1.52" snowstorm = { git = "https://github.com/black-binary/snowstorm", rev = "1887755", optional = true } base64 = { version = "0.13.0", optional = true } +notify = "5.0.0-pre.13" diff --git a/src/cli.rs b/src/cli.rs index 1364bed..4a3481a 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -6,7 +6,7 @@ pub enum KeypairType { X448, } -#[derive(Parser, Debug, Default)] +#[derive(Parser, Debug, Default, Clone)] #[clap(about, version, setting(AppSettings::DeriveDisplayOrder))] #[clap(group( ArgGroup::new("cmds") diff --git a/src/config.rs b/src/config.rs index 6bc4d5b..af09176 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use std::path::Path; use tokio::fs; -#[derive(Debug, Serialize, Deserialize, Copy, Clone)] +#[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq)] pub enum TransportType { #[serde(rename = "tcp")] Tcp, @@ -20,7 +20,7 @@ impl Default for TransportType { } } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct ClientServiceConfig { #[serde(rename = "type", default = "default_service_type")] pub service_type: ServiceType, @@ -30,7 +30,7 @@ pub struct ClientServiceConfig { pub token: Option, } -#[derive(Debug, Serialize, Deserialize, Clone, Copy)] +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] pub enum ServiceType { #[serde(rename = "tcp")] Tcp, @@ -42,7 +42,7 @@ fn default_service_type() -> ServiceType { ServiceType::Tcp } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct ServerServiceConfig { #[serde(rename = "type", default = "default_service_type")] pub service_type: ServiceType, @@ -52,7 +52,7 @@ pub struct ServerServiceConfig { pub token: Option, } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct TlsConfig { pub hostname: Option, pub trusted_root: Option, @@ -64,7 +64,7 @@ fn default_noise_pattern() -> String { String::from("Noise_NK_25519_ChaChaPoly_BLAKE2s") } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct NoiseConfig { #[serde(default = "default_noise_pattern")] pub pattern: String, @@ -73,7 +73,7 @@ pub struct NoiseConfig { // TODO: Maybe psk can be added } -#[derive(Debug, Serialize, Deserialize, Default)] +#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)] pub struct TransportConfig { #[serde(rename = "type")] pub transport_type: TransportType, @@ -85,7 +85,7 @@ fn default_transport() -> TransportConfig { Default::default() } -#[derive(Debug, Serialize, Deserialize, Default)] +#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)] pub struct ClientConfig { pub remote_addr: String, pub default_token: Option, @@ -94,7 +94,7 @@ pub struct ClientConfig { pub transport: TransportConfig, } -#[derive(Debug, Serialize, Deserialize, Default)] +#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)] pub struct ServerConfig { pub bind_addr: String, pub default_token: Option, @@ -103,7 +103,7 @@ pub struct ServerConfig { pub transport: TransportConfig, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] #[serde(deny_unknown_fields)] pub struct Config { pub server: Option, diff --git a/src/config_watcher.rs b/src/config_watcher.rs new file mode 100644 index 0000000..2e262be --- /dev/null +++ b/src/config_watcher.rs @@ -0,0 +1,119 @@ +use crate::{ + config::{ClientServiceConfig, ServerServiceConfig}, + Config, +}; +use anyhow::{Context, Result}; +use notify::{EventKind, RecursiveMode, Watcher}; +use std::path::PathBuf; +use tokio::sync::{broadcast, mpsc}; +use tracing::{error, info, instrument}; + +#[derive(Debug)] +pub enum ConfigChangeEvent { + General(Config), // Trigger a full restart + ServiceChange(ServiceChangeEvent), +} + +#[derive(Debug)] +pub enum ServiceChangeEvent { + AddClientService(ClientServiceConfig), + DeleteClientService(ClientServiceConfig), + AddServerService(ServerServiceConfig), + DeleteServerService(ServerServiceConfig), +} + +pub struct ConfigWatcherHandle { + pub event_rx: mpsc::Receiver, +} + +impl ConfigWatcherHandle { + pub async fn new(path: &PathBuf, shutdown_rx: broadcast::Receiver) -> Result { + let (event_tx, event_rx) = mpsc::channel(16); + + let origin_cfg = Config::from_file(path).await?; + + tokio::spawn(config_watcher( + path.to_owned(), + shutdown_rx, + event_tx, + origin_cfg, + )); + + Ok(ConfigWatcherHandle { event_rx }) + } +} + +#[instrument(skip(shutdown_rx, cfg_event_tx))] +async fn config_watcher( + path: PathBuf, + mut shutdown_rx: broadcast::Receiver, + cfg_event_tx: mpsc::Sender, + mut old: Config, +) -> Result<()> { + let (fevent_tx, mut fevent_rx) = mpsc::channel(16); + + let mut watcher = notify::recommended_watcher(move |res| match res { + Ok(event) => { + let _ = fevent_tx.blocking_send(event); + } + Err(e) => error!("watch error: {:?}", e), + })?; + + // Initial start + cfg_event_tx + .send(ConfigChangeEvent::General(old.clone())) + .await + .unwrap(); + + watcher.watch(&path, RecursiveMode::NonRecursive)?; + info!("Start watching the config"); + + loop { + tokio::select! { + e = fevent_rx.recv() => { + match e { + Some(e) => { + match e.kind { + EventKind::Modify(_) => { + info!("Configuration modify event is detected"); + let new = match Config::from_file(&path).await.with_context(|| "The changed configuration is invalid. Ignored") { + Ok(v) => v, + Err(e) => { + error!("{:?}", e); + // If the config is invalid, just ignore it + continue; + } + }; + + for event in calculate_event(&old, &new) { + cfg_event_tx.send(event).await?; + } + + old = new; + }, + _ => (), // Just ignore other events + } + }, + None => break + } + }, + _ = shutdown_rx.recv() => break + } + } + + info!("Config watcher exiting"); + + Ok(()) +} + +fn calculate_event(old: &Config, new: &Config) -> Vec { + let mut ret = Vec::new(); + + if old == new { + return ret; + } + + ret.push(ConfigChangeEvent::General(new.to_owned())); + + ret +} diff --git a/src/lib.rs b/src/lib.rs index 1cf9299..3379d53 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ mod cli; mod config; +mod config_watcher; mod constants; mod helper; mod multi_map; @@ -9,10 +10,11 @@ mod transport; pub use cli::Cli; use cli::KeypairType; pub use config::Config; +use config_watcher::ServiceChangeEvent; pub use constants::UDP_BUFFER_SIZE; -use anyhow::{anyhow, Result}; -use tokio::sync::broadcast; +use anyhow::Result; +use tokio::sync::{broadcast, mpsc}; use tracing::debug; #[cfg(feature = "client")] @@ -25,6 +27,8 @@ mod server; #[cfg(feature = "server")] use server::run_server; +use crate::config_watcher::{ConfigChangeEvent, ConfigWatcherHandle}; + const DEFAULT_CURVE: KeypairType = KeypairType::X25519; fn get_str_from_keypair_type(curve: KeypairType) -> &'static str { @@ -56,20 +60,68 @@ fn genkey(curve: Option) -> Result<()> { crate::helper::feature_not_compile("nosie") } -pub async fn run(args: &Cli, shutdown_rx: broadcast::Receiver) -> Result<()> { +pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<()> { if args.genkey.is_some() { return genkey(args.genkey.unwrap()); } - let config = Config::from_file(args.config_path.as_ref().unwrap()).await?; - - debug!("{:?}", config); - // Raise `nofile` limit on linux and mac fdlimit::raise_fd_limit(); - match determine_run_mode(&config, args) { - RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")), + // Spawn a config watcher. The watcher will send a initial signal to start the instance with a config + let config_path = args.config_path.as_ref().unwrap(); + let mut cfg_watcher = ConfigWatcherHandle::new(config_path, shutdown_rx).await?; + + // shutdown_tx owns the instance + let (shutdown_tx, _) = broadcast::channel(1); + + // (The join handle of the last instance, The service update channel sender) + let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender)> = + None; + + while let Some(e) = cfg_watcher.event_rx.recv().await { + match e { + ConfigChangeEvent::General(config) => { + match last_instance { + Some((i, _)) => { + shutdown_tx.send(true)?; + i.await??; + } + None => (), + } + + debug!("{:?}", config); + + let (service_update_tx, service_update_rx) = mpsc::channel(1024); + + last_instance = Some(( + tokio::spawn(run_instance( + config.clone(), + args.clone(), + shutdown_tx.subscribe(), + service_update_rx, + )), + service_update_tx, + )); + } + ConfigChangeEvent::ServiceChange(service_event) => { + if let Some((_, service_update_tx)) = &last_instance { + let _ = service_update_tx.send(service_event).await; + } + } + } + } + Ok(()) +} + +async fn run_instance( + config: Config, + args: Cli, + shutdown_rx: broadcast::Receiver, + _service_update: mpsc::Receiver, +) -> Result<()> { + match determine_run_mode(&config, &args) { + RunMode::Undetermine => panic!("Cannot determine running as a server or a client"), RunMode::Client => { #[cfg(not(feature = "client"))] crate::helper::feature_not_compile("client"); diff --git a/src/main.rs b/src/main.rs index 964951c..42ea03c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,5 +29,5 @@ async fn main() -> Result<()> { ) .init(); - run(&args, shutdown_rx).await + run(args, shutdown_rx).await } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 304dff9..9e59f92 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -20,7 +20,7 @@ pub async fn run_rathole_server( client: false, ..Default::default() }; - rathole::run(&cli, shutdown_rx).await + rathole::run(cli, shutdown_rx).await } pub async fn run_rathole_client( @@ -33,7 +33,7 @@ pub async fn run_rathole_client( client: true, ..Default::default() }; - rathole::run(&cli, shutdown_rx).await + rathole::run(cli, shutdown_rx).await } pub mod tcp { diff --git a/tests/integration_test.rs b/tests/integration_test.rs index f7afd4e..5cfc118 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -94,7 +94,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> { // Start the client info!("start the client"); let client = tokio::spawn(async move { - run_rathole_client(&config_path, client_shutdown_rx) + run_rathole_client(config_path, client_shutdown_rx) .await .unwrap(); }); @@ -105,7 +105,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> { // Start the server info!("start the server"); let server = tokio::spawn(async move { - run_rathole_server(&config_path, server_shutdown_rx) + run_rathole_server(config_path, server_shutdown_rx) .await .unwrap(); }); @@ -126,7 +126,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> { info!("restart the client"); let client_shutdown_rx = client_shutdown_tx.subscribe(); let client = tokio::spawn(async move { - run_rathole_client(&config_path, client_shutdown_rx) + run_rathole_client(config_path, client_shutdown_rx) .await .unwrap(); }); @@ -147,7 +147,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> { info!("restart the server"); let server_shutdown_rx = server_shutdown_tx.subscribe(); let server = tokio::spawn(async move { - run_rathole_server(&config_path, server_shutdown_rx) + run_rathole_server(config_path, server_shutdown_rx) .await .unwrap(); });