rathole/src/client.rs

561 lines
19 KiB
Rust

use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
use crate::config_watcher::{ClientServiceChange, ConfigChange};
use crate::helper::udp_connect;
use crate::protocol::Hello::{self, *};
use crate::protocol::{
self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
DataChannelCmd, UdpTraffic, CURRENT_PROTO_VERSION, HASH_WIDTH_IN_BYTES,
};
use crate::transport::{AddrMaybeCached, SocketOpts, TcpTransport, Transport};
use anyhow::{anyhow, bail, Context, Result};
use backoff::backoff::Backoff;
use backoff::future::retry_notify;
use backoff::ExponentialBackoff;
use bytes::{Bytes, BytesMut};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::time::{self, Duration, Instant};
use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};
#[cfg(feature = "noise")]
use crate::transport::NoiseTransport;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
use crate::transport::TlsTransport;
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
use crate::transport::WebsocketTransport;
use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
// The entrypoint of running a client
pub async fn run_client(
config: Config,
shutdown_rx: broadcast::Receiver<bool>,
update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
let config = config.client.ok_or_else(|| {
anyhow!(
"Try to run as a client, but the configuration is missing. Please add the `[client]` block"
)
})?;
match config.transport.transport_type {
TransportType::Tcp => {
let mut client = Client::<TcpTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
}
TransportType::Tls => {
#[cfg(any(feature = "native-tls", feature = "rustls"))]
{
let mut client = Client::<TlsTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
}
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
crate::helper::feature_neither_compile("native-tls", "rustls")
}
TransportType::Noise => {
#[cfg(feature = "noise")]
{
let mut client = Client::<NoiseTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
}
#[cfg(not(feature = "noise"))]
crate::helper::feature_not_compile("noise")
}
TransportType::Websocket => {
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
{
let mut client = Client::<WebsocketTransport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
}
#[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
}
}
}
type ServiceDigest = protocol::Digest;
type Nonce = protocol::Digest;
// Holds the state of a client
struct Client<T: Transport> {
config: ClientConfig,
service_handles: HashMap<String, ControlChannelHandle>,
transport: Arc<T>,
}
impl<T: 'static + Transport> Client<T> {
// Create a Client from `[client]` config block
async fn from(config: ClientConfig) -> Result<Client<T>> {
let transport =
Arc::new(T::new(&config.transport).with_context(|| "Failed to create the transport")?);
Ok(Client {
config,
service_handles: HashMap::new(),
transport,
})
}
// The entrypoint of Client
async fn run(
&mut self,
mut shutdown_rx: broadcast::Receiver<bool>,
mut update_rx: mpsc::Receiver<ConfigChange>,
) -> Result<()> {
for (name, config) in &self.config.services {
// Create a control channel for each service defined
let handle = ControlChannelHandle::new(
(*config).clone(),
self.config.remote_addr.clone(),
self.transport.clone(),
self.config.heartbeat_timeout,
);
self.service_handles.insert(name.clone(), handle);
}
// Wait for the shutdown signal
loop {
tokio::select! {
val = shutdown_rx.recv() => {
match val {
Ok(_) => {}
Err(err) => {
error!("Unable to listen for shutdown signal: {}", err);
}
}
break;
},
e = update_rx.recv() => {
if let Some(e) = e {
self.handle_hot_reload(e).await;
}
}
}
}
// Shutdown all services
for (_, handle) in self.service_handles.drain() {
handle.shutdown();
}
Ok(())
}
async fn handle_hot_reload(&mut self, e: ConfigChange) {
match e {
ConfigChange::ClientChange(client_change) => match client_change {
ClientServiceChange::Add(cfg) => {
let name = cfg.name.clone();
let handle = ControlChannelHandle::new(
cfg,
self.config.remote_addr.clone(),
self.transport.clone(),
self.config.heartbeat_timeout,
);
let _ = self.service_handles.insert(name, handle);
}
ClientServiceChange::Delete(s) => {
let _ = self.service_handles.remove(&s);
}
},
ignored => warn!("Ignored {:?} since running as a client", ignored),
}
}
}
struct RunDataChannelArgs<T: Transport> {
session_key: Nonce,
remote_addr: AddrMaybeCached,
connector: Arc<T>,
socket_opts: SocketOpts,
service: ClientServiceConfig,
}
async fn do_data_channel_handshake<T: Transport>(
args: Arc<RunDataChannelArgs<T>>,
) -> Result<T::Stream> {
// Retry at least every 100ms, at most for 10 seconds
let backoff = ExponentialBackoff {
max_interval: Duration::from_millis(100),
max_elapsed_time: Some(Duration::from_secs(10)),
..Default::default()
};
// Connect to remote_addr
let mut conn: T::Stream = retry_notify(
backoff,
|| async {
args.connector
.connect(&args.remote_addr)
.await
.with_context(|| format!("Failed to connect to {}", &args.remote_addr))
.map_err(backoff::Error::transient)
},
|e, duration| {
warn!("{:#}. Retry in {:?}", e, duration);
},
)
.await?;
T::hint(&conn, args.socket_opts);
// Send nonce
let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();
let hello = Hello::DataChannelHello(CURRENT_PROTO_VERSION, v.to_owned());
conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
conn.flush().await?;
Ok(conn)
}
async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
// Do the handshake
let mut conn = do_data_channel_handshake(args.clone()).await?;
// Forward
match read_data_cmd(&mut conn).await? {
DataChannelCmd::StartForwardTcp => {
if args.service.service_type != ServiceType::Tcp {
bail!("Expect TCP traffic. Please check the configuration.")
}
run_data_channel_for_tcp::<T>(conn, &args.service.local_addr).await?;
}
DataChannelCmd::StartForwardUdp => {
if args.service.service_type != ServiceType::Udp {
bail!("Expect UDP traffic. Please check the configuration.")
}
run_data_channel_for_udp::<T>(conn, &args.service.local_addr).await?;
}
}
Ok(())
}
// Simply copying back and forth for TCP
#[instrument(skip(conn))]
async fn run_data_channel_for_tcp<T: Transport>(
mut conn: T::Stream,
local_addr: &str,
) -> Result<()> {
debug!("New data channel starts forwarding");
let mut local = TcpStream::connect(local_addr)
.await
.with_context(|| format!("Failed to connect to {}", local_addr))?;
let _ = copy_bidirectional(&mut conn, &mut local).await;
Ok(())
}
// Things get a little tricker when it gets to UDP because it's connection-less.
// A UdpPortMap must be maintained for recent seen incoming address, giving them
// each a local port, which is associated with a socket. So just the sender
// to the socket will work fine for the map's value.
type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;
#[instrument(skip(conn))]
async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
debug!("New data channel starts forwarding");
let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
// The channel stores UdpTraffic that needs to be sent to the server
let (outbound_tx, mut outbound_rx) = mpsc::channel::<UdpTraffic>(UDP_SENDQ_SIZE);
// FIXME: https://github.com/tokio-rs/tls/issues/40
// Maybe this is our concern
let (mut rd, mut wr) = io::split(conn);
// Keep sending items from the outbound channel to the server
tokio::spawn(async move {
while let Some(t) = outbound_rx.recv().await {
trace!("outbound {:?}", t);
if let Err(e) = t
.write(&mut wr)
.await
.with_context(|| "Failed to forward UDP traffic to the server")
{
debug!("{:?}", e);
break;
}
}
});
loop {
// Read a packet from the server
let hdr_len = rd.read_u8().await?;
let packet = UdpTraffic::read(&mut rd, hdr_len)
.await
.with_context(|| "Failed to read UDPTraffic from the server")?;
let m = port_map.read().await;
if m.get(&packet.from).is_none() {
// This packet is from a address we don't see for a while,
// which is not in the UdpPortMap.
// So set up a mapping (and a forwarder) for it
// Drop the reader lock
drop(m);
// Grab the writer lock
// This is the only thread that will try to grab the writer lock
// So no need to worry about some other thread has already set up
// the mapping between the gap of dropping the reader lock and
// grabbing the writer lock
let mut m = port_map.write().await;
match udp_connect(local_addr).await {
Ok(s) => {
let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE);
m.insert(packet.from, inbound_tx);
tokio::spawn(run_udp_forwarder(
s,
inbound_rx,
outbound_tx.clone(),
packet.from,
port_map.clone(),
));
}
Err(e) => {
error!("{:#}", e);
}
}
}
// Now there should be a udp forwarder that can receive the packet
let m = port_map.read().await;
if let Some(tx) = m.get(&packet.from) {
let _ = tx.send(packet.data).await;
}
}
}
// Run a UdpSocket for the visitor `from`
#[instrument(skip_all, fields(from))]
async fn run_udp_forwarder(
s: UdpSocket,
mut inbound_rx: mpsc::Receiver<Bytes>,
outbount_tx: mpsc::Sender<UdpTraffic>,
from: SocketAddr,
port_map: UdpPortMap,
) -> Result<()> {
debug!("Forwarder created");
let mut buf = BytesMut::new();
buf.resize(UDP_BUFFER_SIZE, 0);
loop {
tokio::select! {
// Receive from the server
data = inbound_rx.recv() => {
if let Some(data) = data {
s.send(&data).await?;
} else {
break;
}
},
// Receive from the service
val = s.recv(&mut buf) => {
let len = match val {
Ok(v) => v,
Err(_) => break
};
let t = UdpTraffic{
from,
data: Bytes::copy_from_slice(&buf[..len])
};
outbount_tx.send(t).await?;
},
// No traffic for the duration of UDP_TIMEOUT, clean up the state
_ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
break;
}
}
}
let mut port_map = port_map.write().await;
port_map.remove(&from);
debug!("Forwarder dropped");
Ok(())
}
// Control channel, using T as the transport layer
struct ControlChannel<T: Transport> {
digest: ServiceDigest, // SHA256 of the service name
service: ClientServiceConfig, // `[client.services.foo]` config block
shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
remote_addr: String, // `client.remote_addr`
transport: Arc<T>, // Wrapper around the transport layer
heartbeat_timeout: u64, // Application layer heartbeat timeout in secs
}
// Handle of a control channel
// Dropping it will also drop the actual control channel
struct ControlChannelHandle {
shutdown_tx: oneshot::Sender<u8>,
}
impl<T: 'static + Transport> ControlChannel<T> {
#[instrument(skip_all)]
async fn run(&mut self) -> Result<()> {
let mut remote_addr = AddrMaybeCached::new(&self.remote_addr);
remote_addr.resolve().await?;
let mut conn = self
.transport
.connect(&remote_addr)
.await
.with_context(|| format!("Failed to connect to {}", &self.remote_addr))?;
T::hint(&conn, SocketOpts::for_control_channel());
// Send hello
debug!("Sending hello");
let hello_send =
Hello::ControlChannelHello(CURRENT_PROTO_VERSION, self.digest[..].try_into().unwrap());
conn.write_all(&bincode::serialize(&hello_send).unwrap())
.await?;
conn.flush().await?;
// Read hello
debug!("Reading hello");
let nonce = match read_hello(&mut conn).await? {
ControlChannelHello(_, d) => d,
_ => {
bail!("Unexpected type of hello");
}
};
// Send auth
debug!("Sending auth");
let mut concat = Vec::from(self.service.token.as_ref().unwrap().as_bytes());
concat.extend_from_slice(&nonce);
let session_key = protocol::digest(&concat);
let auth = Auth(session_key);
conn.write_all(&bincode::serialize(&auth).unwrap()).await?;
conn.flush().await?;
// Read ack
debug!("Reading ack");
match read_ack(&mut conn).await? {
Ack::Ok => {}
v => {
return Err(anyhow!("{}", v))
.with_context(|| format!("Authentication failed: {}", self.service.name));
}
}
// Channel ready
info!("Control channel established");
// Socket options for the data channel
let socket_opts = SocketOpts::from_client_cfg(&self.service);
let data_ch_args = Arc::new(RunDataChannelArgs {
session_key,
remote_addr,
connector: self.transport.clone(),
socket_opts,
service: self.service.clone(),
});
loop {
tokio::select! {
val = read_control_cmd(&mut conn) => {
let val = val?;
debug!( "Received {:?}", val);
match val {
ControlChannelCmd::CreateDataChannel => {
let args = data_ch_args.clone();
tokio::spawn(async move {
if let Err(e) = run_data_channel(args).await.with_context(|| "Failed to run the data channel") {
warn!("{:#}", e);
}
}.instrument(Span::current()));
},
ControlChannelCmd::HeartBeat => ()
}
},
_ = time::sleep(Duration::from_secs(self.heartbeat_timeout)), if self.heartbeat_timeout != 0 => {
return Err(anyhow!("Heartbeat timed out"))
}
_ = &mut self.shutdown_rx => {
break;
}
}
}
info!("Control channel shutdown");
Ok(())
}
}
impl ControlChannelHandle {
#[instrument(name="handle", skip_all, fields(service = %service.name))]
fn new<T: 'static + Transport>(
service: ClientServiceConfig,
remote_addr: String,
transport: Arc<T>,
heartbeat_timeout: u64,
) -> ControlChannelHandle {
let digest = protocol::digest(service.name.as_bytes());
info!("Starting {}", hex::encode(digest));
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let mut retry_backoff = run_control_chan_backoff(service.retry_interval.unwrap());
let mut s = ControlChannel {
digest,
service,
shutdown_rx,
remote_addr,
transport,
heartbeat_timeout,
};
tokio::spawn(
async move {
let mut start = Instant::now();
while let Err(err) = s
.run()
.await
.with_context(|| "Failed to run the control channel")
{
if s.shutdown_rx.try_recv() != Err(oneshot::error::TryRecvError::Empty) {
break;
}
if start.elapsed() > Duration::from_secs(3) {
// The client runs for at least 3 secs and then disconnects
retry_backoff.reset();
}
if let Some(duration) = retry_backoff.next_backoff() {
error!("{:#}. Retry in {:?}...", err, duration);
time::sleep(duration).await;
} else {
// Should never reach
panic!("{:#}. Break", err);
}
start = Instant::now();
}
}
.instrument(Span::current()),
);
ControlChannelHandle { shutdown_tx }
}
fn shutdown(self) {
// A send failure shows that the actor has already shutdown.
let _ = self.shutdown_tx.send(0u8);
}
}