use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType}; use crate::config_watcher::{ClientServiceChange, ConfigChange}; use crate::helper::udp_connect; use crate::protocol::Hello::{self, *}; use crate::protocol::{ self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd, DataChannelCmd, UdpTraffic, CURRENT_PROTO_VERSION, HASH_WIDTH_IN_BYTES, }; use crate::transport::{AddrMaybeCached, SocketOpts, TcpTransport, Transport}; use anyhow::{anyhow, bail, Context, Result}; use backoff::backoff::Backoff; use backoff::future::retry_notify; use backoff::ExponentialBackoff; use bytes::{Bytes, BytesMut}; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpStream, UdpSocket}; use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; use tokio::time::{self, Duration, Instant}; use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span}; #[cfg(feature = "noise")] use crate::transport::NoiseTransport; #[cfg(any(feature = "native-tls", feature = "rustls"))] use crate::transport::TlsTransport; #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))] use crate::transport::WebsocketTransport; use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT}; // The entrypoint of running a client pub async fn run_client( config: Config, shutdown_rx: broadcast::Receiver, update_rx: mpsc::Receiver, ) -> Result<()> { let config = config.client.ok_or_else(|| { anyhow!( "Try to run as a client, but the configuration is missing. Please add the `[client]` block" ) })?; match config.transport.transport_type { TransportType::Tcp => { let mut client = Client::::from(config).await?; client.run(shutdown_rx, update_rx).await } TransportType::Tls => { #[cfg(any(feature = "native-tls", feature = "rustls"))] { let mut client = Client::::from(config).await?; client.run(shutdown_rx, update_rx).await } #[cfg(not(any(feature = "native-tls", feature = "rustls")))] crate::helper::feature_neither_compile("native-tls", "rustls") } TransportType::Noise => { #[cfg(feature = "noise")] { let mut client = Client::::from(config).await?; client.run(shutdown_rx, update_rx).await } #[cfg(not(feature = "noise"))] crate::helper::feature_not_compile("noise") } TransportType::Websocket => { #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))] { let mut client = Client::::from(config).await?; client.run(shutdown_rx, update_rx).await } #[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))] crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls") } } } type ServiceDigest = protocol::Digest; type Nonce = protocol::Digest; // Holds the state of a client struct Client { config: ClientConfig, service_handles: HashMap, transport: Arc, } impl Client { // Create a Client from `[client]` config block async fn from(config: ClientConfig) -> Result> { let transport = Arc::new(T::new(&config.transport).with_context(|| "Failed to create the transport")?); Ok(Client { config, service_handles: HashMap::new(), transport, }) } // The entrypoint of Client async fn run( &mut self, mut shutdown_rx: broadcast::Receiver, mut update_rx: mpsc::Receiver, ) -> Result<()> { for (name, config) in &self.config.services { // Create a control channel for each service defined let handle = ControlChannelHandle::new( (*config).clone(), self.config.remote_addr.clone(), self.transport.clone(), self.config.heartbeat_timeout, ); self.service_handles.insert(name.clone(), handle); } // Wait for the shutdown signal loop { tokio::select! { val = shutdown_rx.recv() => { match val { Ok(_) => {} Err(err) => { error!("Unable to listen for shutdown signal: {}", err); } } break; }, e = update_rx.recv() => { if let Some(e) = e { self.handle_hot_reload(e).await; } } } } // Shutdown all services for (_, handle) in self.service_handles.drain() { handle.shutdown(); } Ok(()) } async fn handle_hot_reload(&mut self, e: ConfigChange) { match e { ConfigChange::ClientChange(client_change) => match client_change { ClientServiceChange::Add(cfg) => { let name = cfg.name.clone(); let handle = ControlChannelHandle::new( cfg, self.config.remote_addr.clone(), self.transport.clone(), self.config.heartbeat_timeout, ); let _ = self.service_handles.insert(name, handle); } ClientServiceChange::Delete(s) => { let _ = self.service_handles.remove(&s); } }, ignored => warn!("Ignored {:?} since running as a client", ignored), } } } struct RunDataChannelArgs { session_key: Nonce, remote_addr: AddrMaybeCached, connector: Arc, socket_opts: SocketOpts, service: ClientServiceConfig, } async fn do_data_channel_handshake( args: Arc>, ) -> Result { // Retry at least every 100ms, at most for 10 seconds let backoff = ExponentialBackoff { max_interval: Duration::from_millis(100), max_elapsed_time: Some(Duration::from_secs(10)), ..Default::default() }; // Connect to remote_addr let mut conn: T::Stream = retry_notify( backoff, || async { args.connector .connect(&args.remote_addr) .await .with_context(|| format!("Failed to connect to {}", &args.remote_addr)) .map_err(backoff::Error::transient) }, |e, duration| { warn!("{:#}. Retry in {:?}", e, duration); }, ) .await?; T::hint(&conn, args.socket_opts); // Send nonce let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap(); let hello = Hello::DataChannelHello(CURRENT_PROTO_VERSION, v.to_owned()); conn.write_all(&bincode::serialize(&hello).unwrap()).await?; conn.flush().await?; Ok(conn) } async fn run_data_channel(args: Arc>) -> Result<()> { // Do the handshake let mut conn = do_data_channel_handshake(args.clone()).await?; // Forward match read_data_cmd(&mut conn).await? { DataChannelCmd::StartForwardTcp => { if args.service.service_type != ServiceType::Tcp { bail!("Expect TCP traffic. Please check the configuration.") } run_data_channel_for_tcp::(conn, &args.service.local_addr).await?; } DataChannelCmd::StartForwardUdp => { if args.service.service_type != ServiceType::Udp { bail!("Expect UDP traffic. Please check the configuration.") } run_data_channel_for_udp::(conn, &args.service.local_addr).await?; } } Ok(()) } // Simply copying back and forth for TCP #[instrument(skip(conn))] async fn run_data_channel_for_tcp( mut conn: T::Stream, local_addr: &str, ) -> Result<()> { debug!("New data channel starts forwarding"); let mut local = TcpStream::connect(local_addr) .await .with_context(|| format!("Failed to connect to {}", local_addr))?; let _ = copy_bidirectional(&mut conn, &mut local).await; Ok(()) } // Things get a little tricker when it gets to UDP because it's connection-less. // A UdpPortMap must be maintained for recent seen incoming address, giving them // each a local port, which is associated with a socket. So just the sender // to the socket will work fine for the map's value. type UdpPortMap = Arc>>>; #[instrument(skip(conn))] async fn run_data_channel_for_udp(conn: T::Stream, local_addr: &str) -> Result<()> { debug!("New data channel starts forwarding"); let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new())); // The channel stores UdpTraffic that needs to be sent to the server let (outbound_tx, mut outbound_rx) = mpsc::channel::(UDP_SENDQ_SIZE); // FIXME: https://github.com/tokio-rs/tls/issues/40 // Maybe this is our concern let (mut rd, mut wr) = io::split(conn); // Keep sending items from the outbound channel to the server tokio::spawn(async move { while let Some(t) = outbound_rx.recv().await { trace!("outbound {:?}", t); if let Err(e) = t .write(&mut wr) .await .with_context(|| "Failed to forward UDP traffic to the server") { debug!("{:?}", e); break; } } }); loop { // Read a packet from the server let hdr_len = rd.read_u8().await?; let packet = UdpTraffic::read(&mut rd, hdr_len) .await .with_context(|| "Failed to read UDPTraffic from the server")?; let m = port_map.read().await; if m.get(&packet.from).is_none() { // This packet is from a address we don't see for a while, // which is not in the UdpPortMap. // So set up a mapping (and a forwarder) for it // Drop the reader lock drop(m); // Grab the writer lock // This is the only thread that will try to grab the writer lock // So no need to worry about some other thread has already set up // the mapping between the gap of dropping the reader lock and // grabbing the writer lock let mut m = port_map.write().await; match udp_connect(local_addr).await { Ok(s) => { let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE); m.insert(packet.from, inbound_tx); tokio::spawn(run_udp_forwarder( s, inbound_rx, outbound_tx.clone(), packet.from, port_map.clone(), )); } Err(e) => { error!("{:#}", e); } } } // Now there should be a udp forwarder that can receive the packet let m = port_map.read().await; if let Some(tx) = m.get(&packet.from) { let _ = tx.send(packet.data).await; } } } // Run a UdpSocket for the visitor `from` #[instrument(skip_all, fields(from))] async fn run_udp_forwarder( s: UdpSocket, mut inbound_rx: mpsc::Receiver, outbount_tx: mpsc::Sender, from: SocketAddr, port_map: UdpPortMap, ) -> Result<()> { debug!("Forwarder created"); let mut buf = BytesMut::new(); buf.resize(UDP_BUFFER_SIZE, 0); loop { tokio::select! { // Receive from the server data = inbound_rx.recv() => { if let Some(data) = data { s.send(&data).await?; } else { break; } }, // Receive from the service val = s.recv(&mut buf) => { let len = match val { Ok(v) => v, Err(_) => break }; let t = UdpTraffic{ from, data: Bytes::copy_from_slice(&buf[..len]) }; outbount_tx.send(t).await?; }, // No traffic for the duration of UDP_TIMEOUT, clean up the state _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => { break; } } } let mut port_map = port_map.write().await; port_map.remove(&from); debug!("Forwarder dropped"); Ok(()) } // Control channel, using T as the transport layer struct ControlChannel { digest: ServiceDigest, // SHA256 of the service name service: ClientServiceConfig, // `[client.services.foo]` config block shutdown_rx: oneshot::Receiver, // Receives the shutdown signal remote_addr: String, // `client.remote_addr` transport: Arc, // Wrapper around the transport layer heartbeat_timeout: u64, // Application layer heartbeat timeout in secs } // Handle of a control channel // Dropping it will also drop the actual control channel struct ControlChannelHandle { shutdown_tx: oneshot::Sender, } impl ControlChannel { #[instrument(skip_all)] async fn run(&mut self) -> Result<()> { let mut remote_addr = AddrMaybeCached::new(&self.remote_addr); remote_addr.resolve().await?; let mut conn = self .transport .connect(&remote_addr) .await .with_context(|| format!("Failed to connect to {}", &self.remote_addr))?; T::hint(&conn, SocketOpts::for_control_channel()); // Send hello debug!("Sending hello"); let hello_send = Hello::ControlChannelHello(CURRENT_PROTO_VERSION, self.digest[..].try_into().unwrap()); conn.write_all(&bincode::serialize(&hello_send).unwrap()) .await?; conn.flush().await?; // Read hello debug!("Reading hello"); let nonce = match read_hello(&mut conn).await? { ControlChannelHello(_, d) => d, _ => { bail!("Unexpected type of hello"); } }; // Send auth debug!("Sending auth"); let mut concat = Vec::from(self.service.token.as_ref().unwrap().as_bytes()); concat.extend_from_slice(&nonce); let session_key = protocol::digest(&concat); let auth = Auth(session_key); conn.write_all(&bincode::serialize(&auth).unwrap()).await?; conn.flush().await?; // Read ack debug!("Reading ack"); match read_ack(&mut conn).await? { Ack::Ok => {} v => { return Err(anyhow!("{}", v)) .with_context(|| format!("Authentication failed: {}", self.service.name)); } } // Channel ready info!("Control channel established"); // Socket options for the data channel let socket_opts = SocketOpts::from_client_cfg(&self.service); let data_ch_args = Arc::new(RunDataChannelArgs { session_key, remote_addr, connector: self.transport.clone(), socket_opts, service: self.service.clone(), }); loop { tokio::select! { val = read_control_cmd(&mut conn) => { let val = val?; debug!( "Received {:?}", val); match val { ControlChannelCmd::CreateDataChannel => { let args = data_ch_args.clone(); tokio::spawn(async move { if let Err(e) = run_data_channel(args).await.with_context(|| "Failed to run the data channel") { warn!("{:#}", e); } }.instrument(Span::current())); }, ControlChannelCmd::HeartBeat => () } }, _ = time::sleep(Duration::from_secs(self.heartbeat_timeout)), if self.heartbeat_timeout != 0 => { return Err(anyhow!("Heartbeat timed out")) } _ = &mut self.shutdown_rx => { break; } } } info!("Control channel shutdown"); Ok(()) } } impl ControlChannelHandle { #[instrument(name="handle", skip_all, fields(service = %service.name))] fn new( service: ClientServiceConfig, remote_addr: String, transport: Arc, heartbeat_timeout: u64, ) -> ControlChannelHandle { let digest = protocol::digest(service.name.as_bytes()); info!("Starting {}", hex::encode(digest)); let (shutdown_tx, shutdown_rx) = oneshot::channel(); let mut retry_backoff = run_control_chan_backoff(service.retry_interval.unwrap()); let mut s = ControlChannel { digest, service, shutdown_rx, remote_addr, transport, heartbeat_timeout, }; tokio::spawn( async move { let mut start = Instant::now(); while let Err(err) = s .run() .await .with_context(|| "Failed to run the control channel") { if s.shutdown_rx.try_recv() != Err(oneshot::error::TryRecvError::Empty) { break; } if start.elapsed() > Duration::from_secs(3) { // The client runs for at least 3 secs and then disconnects retry_backoff.reset(); } if let Some(duration) = retry_backoff.next_backoff() { error!("{:#}. Retry in {:?}...", err, duration); time::sleep(duration).await; } else { // Should never reach panic!("{:#}. Break", err); } start = Instant::now(); } } .instrument(Span::current()), ); ControlChannelHandle { shutdown_tx } } fn shutdown(self) { // A send failure shows that the actor has already shutdown. let _ = self.shutdown_tx.send(0u8); } }