// mirror of https://github.com/rapiz1/rathole.git
use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
use crate::config_watcher::{ClientServiceChange, ConfigChange};
use crate::helper::udp_connect;
use crate::protocol::Hello::{self, *};
use crate::protocol::{
    self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
    DataChannelCmd, UdpTraffic, CURRENT_PROTO_VERSION, HASH_WIDTH_IN_BYTES,
};
use crate::transport::{AddrMaybeCached, SocketOpts, TcpTransport, Transport};
use anyhow::{anyhow, bail, Context, Result};
use backoff::backoff::Backoff;
use backoff::future::retry_notify;
use backoff::ExponentialBackoff;
use bytes::{Bytes, BytesMut};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::time::{self, Duration, Instant};
use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};

#[cfg(feature = "noise")]
use crate::transport::NoiseTransport;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
use crate::transport::TlsTransport;
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
use crate::transport::WebsocketTransport;

use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
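// For reference, a minimal `[client]` configuration that exercises this module
// looks roughly like the following (illustrative values; see the project README
// for the authoritative format):
//
//     [client]
//     remote_addr = "example.com:2333"
//
//     [client.services.my_ssh]
//     token = "a_secret_only_you_know"
//     local_addr = "127.0.0.1:22"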
// The entrypoint of running a client
pub async fn run_client(
    config: Config,
    shutdown_rx: broadcast::Receiver<bool>,
    update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
    let config = config.client.ok_or_else(|| {
        anyhow!(
            "Trying to run as a client, but the configuration is missing. Please add the `[client]` block"
        )
    })?;

    match config.transport.transport_type {
        TransportType::Tcp => {
            let mut client = Client::<TcpTransport>::from(config).await?;
            client.run(shutdown_rx, update_rx).await
        }
        TransportType::Tls => {
            #[cfg(any(feature = "native-tls", feature = "rustls"))]
            {
                let mut client = Client::<TlsTransport>::from(config).await?;
                client.run(shutdown_rx, update_rx).await
            }
            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
            crate::helper::feature_neither_compile("native-tls", "rustls")
        }
        TransportType::Noise => {
            #[cfg(feature = "noise")]
            {
                let mut client = Client::<NoiseTransport>::from(config).await?;
                client.run(shutdown_rx, update_rx).await
            }
            #[cfg(not(feature = "noise"))]
            crate::helper::feature_not_compile("noise")
        }
        TransportType::Websocket => {
            #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
            {
                let mut client = Client::<WebsocketTransport>::from(config).await?;
                client.run(shutdown_rx, update_rx).await
            }
            #[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
            crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
        }
    }
}
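// Every transport except plain TCP is gated behind a Cargo feature
// ("native-tls"/"rustls", "noise", "websocket-native-tls"/"websocket-rustls");
// selecting a transport that was not compiled in is reported through the
// `crate::helper::feature_*_compile` helpers instead of silently falling back.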
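// `protocol::Digest` is a fixed-width hash of `HASH_WIDTH_IN_BYTES` bytes.
// A `ServiceDigest` is the digest of a service name, and a `Nonce` is the
// challenge the server sends during the control channel handshake.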
type ServiceDigest = protocol::Digest;
type Nonce = protocol::Digest;

// Holds the state of a client
struct Client<T: Transport> {
    config: ClientConfig,
    service_handles: HashMap<String, ControlChannelHandle>,
    transport: Arc<T>,
}

impl<T: 'static + Transport> Client<T> {
    // Create a Client from the `[client]` config block
    async fn from(config: ClientConfig) -> Result<Client<T>> {
        let transport =
            Arc::new(T::new(&config.transport).with_context(|| "Failed to create the transport")?);
        Ok(Client {
            config,
            service_handles: HashMap::new(),
            transport,
        })
    }

    // The entrypoint of Client
    async fn run(
        &mut self,
        mut shutdown_rx: broadcast::Receiver<bool>,
        mut update_rx: mpsc::Receiver<ConfigChange>,
    ) -> Result<()> {
        // Create a control channel for each service defined
        for (name, config) in &self.config.services {
            let handle = ControlChannelHandle::new(
                (*config).clone(),
                self.config.remote_addr.clone(),
                self.transport.clone(),
                self.config.heartbeat_timeout,
            );
            self.service_handles.insert(name.clone(), handle);
        }

        // Wait for the shutdown signal, handling hot-reload events in the meantime
        loop {
            tokio::select! {
                val = shutdown_rx.recv() => {
                    match val {
                        Ok(_) => {}
                        Err(err) => {
                            error!("Unable to listen for shutdown signal: {}", err);
                        }
                    }
                    break;
                },
                e = update_rx.recv() => {
                    if let Some(e) = e {
                        self.handle_hot_reload(e).await;
                    }
                }
            }
        }

        // Shutdown all services
        for (_, handle) in self.service_handles.drain() {
            handle.shutdown();
        }

        Ok(())
    }

    async fn handle_hot_reload(&mut self, e: ConfigChange) {
        match e {
            ConfigChange::ClientChange(client_change) => match client_change {
                ClientServiceChange::Add(cfg) => {
                    let name = cfg.name.clone();
                    let handle = ControlChannelHandle::new(
                        cfg,
                        self.config.remote_addr.clone(),
                        self.transport.clone(),
                        self.config.heartbeat_timeout,
                    );
                    // If a handle with the same name already exists, it is
                    // returned here and dropped, which shuts down the old
                    // control channel
                    let _ = self.service_handles.insert(name, handle);
                }
                ClientServiceChange::Delete(s) => {
                    // Dropping the removed handle shuts down its control channel
                    let _ = self.service_handles.remove(&s);
                }
            },
            ignored => warn!("Ignored {:?} since running as a client", ignored),
        }
    }
}
struct RunDataChannelArgs<T: Transport> {
    session_key: Nonce,
    remote_addr: AddrMaybeCached,
    connector: Arc<T>,
    socket_opts: SocketOpts,
    service: ClientServiceConfig,
}
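// The args are created once per control channel and shared behind an `Arc`,
// so each `CreateDataChannel` command can spawn a data channel by cloning a
// pointer rather than the configuration itself.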
async fn do_data_channel_handshake<T: Transport>(
    args: Arc<RunDataChannelArgs<T>>,
) -> Result<T::Stream> {
    // Retry at intervals of at most 100ms, for at most 10 seconds in total
    let backoff = ExponentialBackoff {
        max_interval: Duration::from_millis(100),
        max_elapsed_time: Some(Duration::from_secs(10)),
        ..Default::default()
    };

    // Connect to remote_addr
    let mut conn: T::Stream = retry_notify(
        backoff,
        || async {
            args.connector
                .connect(&args.remote_addr)
                .await
                .with_context(|| format!("Failed to connect to {}", &args.remote_addr))
                .map_err(backoff::Error::transient)
        },
        |e, duration| {
            warn!("{:#}. Retry in {:?}", e, duration);
        },
    )
    .await?;

    T::hint(&conn, args.socket_opts);

    // Send the session key as the data channel hello
    let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();
    let hello = Hello::DataChannelHello(CURRENT_PROTO_VERSION, v.to_owned());
    conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
    conn.flush().await?;

    Ok(conn)
}
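// A data channel connects over the same transport as the control channel,
// identifies itself with the session key negotiated on the control channel,
// and then waits for the server to say which protocol (TCP or UDP) to forward.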
async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
    // Do the handshake
    let mut conn = do_data_channel_handshake(args.clone()).await?;

    // Forward. The server decides the forwarding mode; the client only checks
    // that it matches the locally configured service type
    match read_data_cmd(&mut conn).await? {
        DataChannelCmd::StartForwardTcp => {
            if args.service.service_type != ServiceType::Tcp {
                bail!("Expected TCP traffic. Please check the configuration.")
            }
            run_data_channel_for_tcp::<T>(conn, &args.service.local_addr).await?;
        }
        DataChannelCmd::StartForwardUdp => {
            if args.service.service_type != ServiceType::Udp {
                bail!("Expected UDP traffic. Please check the configuration.")
            }
            run_data_channel_for_udp::<T>(conn, &args.service.local_addr).await?;
        }
    }
    Ok(())
}
// Simply copying back and forth for TCP
#[instrument(skip(conn))]
async fn run_data_channel_for_tcp<T: Transport>(
    mut conn: T::Stream,
    local_addr: &str,
) -> Result<()> {
    debug!("New data channel starts forwarding");

    let mut local = TcpStream::connect(local_addr)
        .await
        .with_context(|| format!("Failed to connect to {}", local_addr))?;
    let _ = copy_bidirectional(&mut conn, &mut local).await;
    Ok(())
}
// Things get a little trickier when it comes to UDP, because it is
// connectionless. A UdpPortMap must be maintained for recently seen incoming
// addresses, giving each of them a local socket driven by a forwarder task.
// The map's value is just the sender half of a channel into that forwarder.
type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;
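// Packet flow for UDP, as implemented below:
//
//   server ==data channel==> run_data_channel_for_udp --inbound_tx--> run_udp_forwarder --> local service
//   local service --> run_udp_forwarder --outbound_tx--> run_data_channel_for_udp ==data channel==> server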
#[instrument(skip(conn))]
async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
    debug!("New data channel starts forwarding");

    let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));

    // The channel stores UdpTraffic that needs to be sent to the server
    let (outbound_tx, mut outbound_rx) = mpsc::channel::<UdpTraffic>(UDP_SENDQ_SIZE);

    // FIXME: https://github.com/tokio-rs/tls/issues/40
    // Maybe this is our concern
    let (mut rd, mut wr) = io::split(conn);

    // Keep sending items from the outbound channel to the server
    tokio::spawn(async move {
        while let Some(t) = outbound_rx.recv().await {
            trace!("outbound {:?}", t);
            if let Err(e) = t
                .write(&mut wr)
                .await
                .with_context(|| "Failed to forward UDP traffic to the server")
            {
                debug!("{:?}", e);
                break;
            }
        }
    });

    loop {
        // Read a packet from the server
        let hdr_len = rd.read_u8().await?;
        let packet = UdpTraffic::read(&mut rd, hdr_len)
            .await
            .with_context(|| "Failed to read UdpTraffic from the server")?;
        let m = port_map.read().await;

        if m.get(&packet.from).is_none() {
            // This packet is from an address we haven't seen for a while,
            // which is not in the UdpPortMap.
            // So set up a mapping (and a forwarder) for it

            // Drop the reader lock
            drop(m);

            // Grab the writer lock.
            // This is the only task that grabs the writer lock, so there is
            // no need to worry about another task setting up the mapping in
            // the window between dropping the reader lock and grabbing the
            // writer lock
            let mut m = port_map.write().await;

            match udp_connect(local_addr).await {
                Ok(s) => {
                    let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE);
                    m.insert(packet.from, inbound_tx);
                    tokio::spawn(run_udp_forwarder(
                        s,
                        inbound_rx,
                        outbound_tx.clone(),
                        packet.from,
                        port_map.clone(),
                    ));
                }
                Err(e) => {
                    error!("{:#}", e);
                }
            }
        }

        // Now there should be a UDP forwarder that can receive the packet
        let m = port_map.read().await;
        if let Some(tx) = m.get(&packet.from) {
            let _ = tx.send(packet.data).await;
        }
    }
}
// Run a UdpSocket for the visitor `from`
#[instrument(skip_all, fields(from))]
async fn run_udp_forwarder(
    s: UdpSocket,
    mut inbound_rx: mpsc::Receiver<Bytes>,
    outbound_tx: mpsc::Sender<UdpTraffic>,
    from: SocketAddr,
    port_map: UdpPortMap,
) -> Result<()> {
    debug!("Forwarder created");
    let mut buf = BytesMut::new();
    buf.resize(UDP_BUFFER_SIZE, 0);

    loop {
        tokio::select! {
            // Receive from the server
            data = inbound_rx.recv() => {
                if let Some(data) = data {
                    s.send(&data).await?;
                } else {
                    break;
                }
            },

            // Receive from the service
            val = s.recv(&mut buf) => {
                let len = match val {
                    Ok(v) => v,
                    Err(_) => break
                };

                let t = UdpTraffic {
                    from,
                    data: Bytes::copy_from_slice(&buf[..len])
                };

                outbound_tx.send(t).await?;
            },

            // No traffic for the duration of UDP_TIMEOUT, clean up the state
            _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
                break;
            }
        }
    }

    let mut port_map = port_map.write().await;
    port_map.remove(&from);

    debug!("Forwarder dropped");
    Ok(())
}
// Control channel, using T as the transport layer
struct ControlChannel<T: Transport> {
    digest: ServiceDigest,              // SHA256 of the service name
    service: ClientServiceConfig,       // `[client.services.foo]` config block
    shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
    remote_addr: String,                // `client.remote_addr`
    transport: Arc<T>,                  // Wrapper around the transport layer
    heartbeat_timeout: u64,             // Application layer heartbeat timeout in secs
}
// Handle of a control channel.
// Dropping it will also drop the actual control channel
struct ControlChannelHandle {
    shutdown_tx: oneshot::Sender<u8>,
}
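// The shutdown signal travels over the oneshot channel: sending `0u8`, or
// simply dropping `shutdown_tx`, completes the receiver, which both
// `ControlChannel::run` and the reconnect loop below treat as a stop request.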
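// Control channel handshake, as implemented in `run` below:
//
//   client -> server: ControlChannelHello(version, digest(service name))
//   server -> client: ControlChannelHello(version, nonce)
//   client -> server: Auth(digest(token || nonce))
//   server -> client: Ack::Ok, or an error
//
// The digest of token and nonce doubles as the session key that data
// channels later present in their DataChannelHello.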
impl<T: 'static + Transport> ControlChannel<T> {
    #[instrument(skip_all)]
    async fn run(&mut self) -> Result<()> {
        let mut remote_addr = AddrMaybeCached::new(&self.remote_addr);
        remote_addr.resolve().await?;

        let mut conn = self
            .transport
            .connect(&remote_addr)
            .await
            .with_context(|| format!("Failed to connect to {}", &self.remote_addr))?;
        T::hint(&conn, SocketOpts::for_control_channel());

        // Send hello
        debug!("Sending hello");
        let hello_send =
            Hello::ControlChannelHello(CURRENT_PROTO_VERSION, self.digest[..].try_into().unwrap());
        conn.write_all(&bincode::serialize(&hello_send).unwrap())
            .await?;
        conn.flush().await?;

        // Read hello
        debug!("Reading hello");
        let nonce = match read_hello(&mut conn).await? {
            ControlChannelHello(_, d) => d,
            _ => {
                bail!("Unexpected type of hello");
            }
        };

        // Send auth. The token is expected to have been filled in during
        // config loading, hence the `unwrap`
        debug!("Sending auth");
        let mut concat = Vec::from(self.service.token.as_ref().unwrap().as_bytes());
        concat.extend_from_slice(&nonce);

        let session_key = protocol::digest(&concat);
        let auth = Auth(session_key);
        conn.write_all(&bincode::serialize(&auth).unwrap()).await?;
        conn.flush().await?;

        // Read ack
        debug!("Reading ack");
        match read_ack(&mut conn).await? {
            Ack::Ok => {}
            v => {
                return Err(anyhow!("{}", v))
                    .with_context(|| format!("Authentication failed: {}", self.service.name));
            }
        }

        // Channel ready
        info!("Control channel established");

        // Socket options for the data channels
        let socket_opts = SocketOpts::from_client_cfg(&self.service);
        let data_ch_args = Arc::new(RunDataChannelArgs {
            session_key,
            remote_addr,
            connector: self.transport.clone(),
            socket_opts,
            service: self.service.clone(),
        });

        loop {
            tokio::select! {
                val = read_control_cmd(&mut conn) => {
                    let val = val?;
                    debug!("Received {:?}", val);
                    match val {
                        ControlChannelCmd::CreateDataChannel => {
                            let args = data_ch_args.clone();
                            tokio::spawn(async move {
                                if let Err(e) = run_data_channel(args).await.with_context(|| "Failed to run the data channel") {
                                    warn!("{:#}", e);
                                }
                            }.instrument(Span::current()));
                        },
                        ControlChannelCmd::HeartBeat => ()
                    }
                },
                _ = time::sleep(Duration::from_secs(self.heartbeat_timeout)), if self.heartbeat_timeout != 0 => {
                    return Err(anyhow!("Heartbeat timed out"))
                }
                _ = &mut self.shutdown_rx => {
                    break;
                }
            }
        }

        info!("Control channel shutdown");
        Ok(())
    }
}
impl ControlChannelHandle {
    #[instrument(name = "handle", skip_all, fields(service = %service.name))]
    fn new<T: 'static + Transport>(
        service: ClientServiceConfig,
        remote_addr: String,
        transport: Arc<T>,
        heartbeat_timeout: u64,
    ) -> ControlChannelHandle {
        let digest = protocol::digest(service.name.as_bytes());

        info!("Starting {}", hex::encode(digest));
        let (shutdown_tx, shutdown_rx) = oneshot::channel();

        let mut retry_backoff = run_control_chan_backoff(service.retry_interval.unwrap());

        let mut s = ControlChannel {
            digest,
            service,
            shutdown_rx,
            remote_addr,
            transport,
            heartbeat_timeout,
        };

        tokio::spawn(
            async move {
                let mut start = Instant::now();

                while let Err(err) = s
                    .run()
                    .await
                    .with_context(|| "Failed to run the control channel")
                {
                    if s.shutdown_rx.try_recv() != Err(oneshot::error::TryRecvError::Empty) {
                        break;
                    }

                    if start.elapsed() > Duration::from_secs(3) {
                        // The control channel ran for at least 3 secs before
                        // disconnecting, so consider the connection healthy
                        // and reset the backoff
                        retry_backoff.reset();
                    }

                    if let Some(duration) = retry_backoff.next_backoff() {
                        error!("{:#}. Retry in {:?}...", err, duration);
                        time::sleep(duration).await;
                    } else {
                        // Should never be reached
                        panic!("{:#}. Break", err);
                    }

                    start = Instant::now();
                }
            }
            .instrument(Span::current()),
        );

        ControlChannelHandle { shutdown_tx }
    }

    fn shutdown(self) {
        // A send failure means the actor has already shut down
        let _ = self.shutdown_tx.send(0u8);
    }
}