feat: hot-reload by restarting

This commit is contained in:
Yujia Qiao 2021-12-25 20:23:56 +08:00 committed by Yujia Qiao
parent 24959daa93
commit 8097b6916f
9 changed files with 319 additions and 27 deletions

120
Cargo.lock generated
View File

@ -272,6 +272,26 @@ dependencies = [
"libc",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db"
dependencies = [
"cfg-if",
"lazy_static",
]
[[package]]
name = "crypto-common"
version = "0.1.1"
@ -342,6 +362,18 @@ dependencies = [
"libc",
]
[[package]]
name = "filetime"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "975ccf83d8d9d0d84682850a38c8169027be83368805971cc4f238c2b245bc98"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"winapi",
]
[[package]]
name = "foreign-types"
version = "0.3.2"
@ -357,6 +389,15 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "fsevent-sys"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c0e564d24da983c053beff1bb7178e237501206840a3e6bf4e267b9e8ae734a"
dependencies = [
"libc",
]
[[package]]
name = "futures-core"
version = "0.3.19"
@ -476,6 +517,26 @@ dependencies = [
"hashbrown",
]
[[package]]
name = "inotify"
version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff"
dependencies = [
"bitflags",
"inotify-sys",
"libc",
]
[[package]]
name = "inotify-sys"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb"
dependencies = [
"libc",
]
[[package]]
name = "instant"
version = "0.1.12"
@ -491,6 +552,26 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35"
[[package]]
name = "kqueue"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "058a107a784f8be94c7d35c1300f4facced2e93d2fbe5b1452b44e905ddca4a9"
dependencies = [
"kqueue-sys",
"libc",
]
[[package]]
name = "kqueue-sys"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8367585489f01bc55dd27404dcf56b95e6da061a256a666ab23be9ba96a2e587"
dependencies = [
"bitflags",
"libc",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
@ -576,6 +657,24 @@ dependencies = [
"tempfile",
]
[[package]]
name = "notify"
version = "5.0.0-pre.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "245d358380e2352c2d020e8ee62baac09b3420f1f6c012a31326cfced4ad487d"
dependencies = [
"bitflags",
"crossbeam-channel",
"filetime",
"fsevent-sys",
"inotify",
"kqueue",
"libc",
"mio",
"walkdir",
"winapi",
]
[[package]]
name = "ntapi"
version = "0.3.6"
@ -874,6 +973,7 @@ dependencies = [
"fdlimit",
"hex",
"lazy_static",
"notify",
"rand",
"serde",
"sha2 0.10.0",
@ -943,6 +1043,15 @@ version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.19"
@ -1389,6 +1498,17 @@ version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe"
[[package]]
name = "walkdir"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
dependencies = [
"same-file",
"winapi",
"winapi-util",
]
[[package]]
name = "wasi"
version = "0.9.0+wasi-snapshot-preview1"

View File

@ -50,3 +50,4 @@ tokio-native-tls = { version = "0.3.0", optional = true }
async-trait = "0.1.52"
snowstorm = { git = "https://github.com/black-binary/snowstorm", rev = "1887755", optional = true }
base64 = { version = "0.13.0", optional = true }
notify = "5.0.0-pre.13"

View File

@ -6,7 +6,7 @@ pub enum KeypairType {
X448,
}
#[derive(Parser, Debug, Default)]
#[derive(Parser, Debug, Default, Clone)]
#[clap(about, version, setting(AppSettings::DeriveDisplayOrder))]
#[clap(group(
ArgGroup::new("cmds")

View File

@ -4,7 +4,7 @@ use std::collections::HashMap;
use std::path::Path;
use tokio::fs;
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq)]
pub enum TransportType {
#[serde(rename = "tcp")]
Tcp,
@ -20,7 +20,7 @@ impl Default for TransportType {
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ClientServiceConfig {
#[serde(rename = "type", default = "default_service_type")]
pub service_type: ServiceType,
@ -30,7 +30,7 @@ pub struct ClientServiceConfig {
pub token: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
pub enum ServiceType {
#[serde(rename = "tcp")]
Tcp,
@ -42,7 +42,7 @@ fn default_service_type() -> ServiceType {
ServiceType::Tcp
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ServerServiceConfig {
#[serde(rename = "type", default = "default_service_type")]
pub service_type: ServiceType,
@ -52,7 +52,7 @@ pub struct ServerServiceConfig {
pub token: Option<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct TlsConfig {
pub hostname: Option<String>,
pub trusted_root: Option<String>,
@ -64,7 +64,7 @@ fn default_noise_pattern() -> String {
String::from("Noise_NK_25519_ChaChaPoly_BLAKE2s")
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct NoiseConfig {
#[serde(default = "default_noise_pattern")]
pub pattern: String,
@ -73,7 +73,7 @@ pub struct NoiseConfig {
// TODO: Maybe psk can be added
}
#[derive(Debug, Serialize, Deserialize, Default)]
#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
pub struct TransportConfig {
#[serde(rename = "type")]
pub transport_type: TransportType,
@ -85,7 +85,7 @@ fn default_transport() -> TransportConfig {
Default::default()
}
#[derive(Debug, Serialize, Deserialize, Default)]
#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
pub struct ClientConfig {
pub remote_addr: String,
pub default_token: Option<String>,
@ -94,7 +94,7 @@ pub struct ClientConfig {
pub transport: TransportConfig,
}
#[derive(Debug, Serialize, Deserialize, Default)]
#[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)]
pub struct ServerConfig {
pub bind_addr: String,
pub default_token: Option<String>,
@ -103,7 +103,7 @@ pub struct ServerConfig {
pub transport: TransportConfig,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(deny_unknown_fields)]
pub struct Config {
pub server: Option<ServerConfig>,

119
src/config_watcher.rs Normal file
View File

@ -0,0 +1,119 @@
use crate::{
config::{ClientServiceConfig, ServerServiceConfig},
Config,
};
use anyhow::{Context, Result};
use notify::{EventKind, RecursiveMode, Watcher};
use std::path::PathBuf;
use tokio::sync::{broadcast, mpsc};
use tracing::{error, info, instrument};
#[derive(Debug)]
pub enum ConfigChangeEvent {
General(Config), // Trigger a full restart
ServiceChange(ServiceChangeEvent),
}
#[derive(Debug)]
pub enum ServiceChangeEvent {
AddClientService(ClientServiceConfig),
DeleteClientService(ClientServiceConfig),
AddServerService(ServerServiceConfig),
DeleteServerService(ServerServiceConfig),
}
pub struct ConfigWatcherHandle {
pub event_rx: mpsc::Receiver<ConfigChangeEvent>,
}
impl ConfigWatcherHandle {
pub async fn new(path: &PathBuf, shutdown_rx: broadcast::Receiver<bool>) -> Result<Self> {
let (event_tx, event_rx) = mpsc::channel(16);
let origin_cfg = Config::from_file(path).await?;
tokio::spawn(config_watcher(
path.to_owned(),
shutdown_rx,
event_tx,
origin_cfg,
));
Ok(ConfigWatcherHandle { event_rx })
}
}
#[instrument(skip(shutdown_rx, cfg_event_tx))]
async fn config_watcher(
path: PathBuf,
mut shutdown_rx: broadcast::Receiver<bool>,
cfg_event_tx: mpsc::Sender<ConfigChangeEvent>,
mut old: Config,
) -> Result<()> {
let (fevent_tx, mut fevent_rx) = mpsc::channel(16);
let mut watcher = notify::recommended_watcher(move |res| match res {
Ok(event) => {
let _ = fevent_tx.blocking_send(event);
}
Err(e) => error!("watch error: {:?}", e),
})?;
// Initial start
cfg_event_tx
.send(ConfigChangeEvent::General(old.clone()))
.await
.unwrap();
watcher.watch(&path, RecursiveMode::NonRecursive)?;
info!("Start watching the config");
loop {
tokio::select! {
e = fevent_rx.recv() => {
match e {
Some(e) => {
match e.kind {
EventKind::Modify(_) => {
info!("Configuration modify event is detected");
let new = match Config::from_file(&path).await.with_context(|| "The changed configuration is invalid. Ignored") {
Ok(v) => v,
Err(e) => {
error!("{:?}", e);
// If the config is invalid, just ignore it
continue;
}
};
for event in calculate_event(&old, &new) {
cfg_event_tx.send(event).await?;
}
old = new;
},
_ => (), // Just ignore other events
}
},
None => break
}
},
_ = shutdown_rx.recv() => break
}
}
info!("Config watcher exiting");
Ok(())
}
fn calculate_event(old: &Config, new: &Config) -> Vec<ConfigChangeEvent> {
let mut ret = Vec::new();
if old == new {
return ret;
}
ret.push(ConfigChangeEvent::General(new.to_owned()));
ret
}

View File

@ -1,5 +1,6 @@
mod cli;
mod config;
mod config_watcher;
mod constants;
mod helper;
mod multi_map;
@ -9,10 +10,11 @@ mod transport;
pub use cli::Cli;
use cli::KeypairType;
pub use config::Config;
use config_watcher::ServiceChangeEvent;
pub use constants::UDP_BUFFER_SIZE;
use anyhow::{anyhow, Result};
use tokio::sync::broadcast;
use anyhow::Result;
use tokio::sync::{broadcast, mpsc};
use tracing::debug;
#[cfg(feature = "client")]
@ -25,6 +27,8 @@ mod server;
#[cfg(feature = "server")]
use server::run_server;
use crate::config_watcher::{ConfigChangeEvent, ConfigWatcherHandle};
const DEFAULT_CURVE: KeypairType = KeypairType::X25519;
fn get_str_from_keypair_type(curve: KeypairType) -> &'static str {
@ -56,20 +60,68 @@ fn genkey(curve: Option<KeypairType>) -> Result<()> {
crate::helper::feature_not_compile("nosie")
}
pub async fn run(args: &Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
if args.genkey.is_some() {
return genkey(args.genkey.unwrap());
}
let config = Config::from_file(args.config_path.as_ref().unwrap()).await?;
debug!("{:?}", config);
// Raise `nofile` limit on linux and mac
fdlimit::raise_fd_limit();
match determine_run_mode(&config, args) {
RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")),
// Spawn a config watcher. The watcher will send a initial signal to start the instance with a config
let config_path = args.config_path.as_ref().unwrap();
let mut cfg_watcher = ConfigWatcherHandle::new(config_path, shutdown_rx).await?;
// shutdown_tx owns the instance
let (shutdown_tx, _) = broadcast::channel(1);
// (The join handle of the last instance, The service update channel sender)
let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender<ServiceChangeEvent>)> =
None;
while let Some(e) = cfg_watcher.event_rx.recv().await {
match e {
ConfigChangeEvent::General(config) => {
match last_instance {
Some((i, _)) => {
shutdown_tx.send(true)?;
i.await??;
}
None => (),
}
debug!("{:?}", config);
let (service_update_tx, service_update_rx) = mpsc::channel(1024);
last_instance = Some((
tokio::spawn(run_instance(
config.clone(),
args.clone(),
shutdown_tx.subscribe(),
service_update_rx,
)),
service_update_tx,
));
}
ConfigChangeEvent::ServiceChange(service_event) => {
if let Some((_, service_update_tx)) = &last_instance {
let _ = service_update_tx.send(service_event).await;
}
}
}
}
Ok(())
}
async fn run_instance(
config: Config,
args: Cli,
shutdown_rx: broadcast::Receiver<bool>,
_service_update: mpsc::Receiver<ServiceChangeEvent>,
) -> Result<()> {
match determine_run_mode(&config, &args) {
RunMode::Undetermine => panic!("Cannot determine running as a server or a client"),
RunMode::Client => {
#[cfg(not(feature = "client"))]
crate::helper::feature_not_compile("client");

View File

@ -29,5 +29,5 @@ async fn main() -> Result<()> {
)
.init();
run(&args, shutdown_rx).await
run(args, shutdown_rx).await
}

View File

@ -20,7 +20,7 @@ pub async fn run_rathole_server(
client: false,
..Default::default()
};
rathole::run(&cli, shutdown_rx).await
rathole::run(cli, shutdown_rx).await
}
pub async fn run_rathole_client(
@ -33,7 +33,7 @@ pub async fn run_rathole_client(
client: true,
..Default::default()
};
rathole::run(&cli, shutdown_rx).await
rathole::run(cli, shutdown_rx).await
}
pub mod tcp {

View File

@ -94,7 +94,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {
// Start the client
info!("start the client");
let client = tokio::spawn(async move {
run_rathole_client(&config_path, client_shutdown_rx)
run_rathole_client(config_path, client_shutdown_rx)
.await
.unwrap();
});
@ -105,7 +105,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {
// Start the server
info!("start the server");
let server = tokio::spawn(async move {
run_rathole_server(&config_path, server_shutdown_rx)
run_rathole_server(config_path, server_shutdown_rx)
.await
.unwrap();
});
@ -126,7 +126,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {
info!("restart the client");
let client_shutdown_rx = client_shutdown_tx.subscribe();
let client = tokio::spawn(async move {
run_rathole_client(&config_path, client_shutdown_rx)
run_rathole_client(config_path, client_shutdown_rx)
.await
.unwrap();
});
@ -147,7 +147,7 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> {
info!("restart the server");
let server_shutdown_rx = server_shutdown_tx.subscribe();
let server = tokio::spawn(async move {
run_rathole_server(&config_path, server_shutdown_rx)
run_rathole_server(config_path, server_shutdown_rx)
.await
.unwrap();
});