refactor: fix clippy, merge imports

Fix lints of clippy
Merge imports
This commit is contained in:
Yujia Qiao 2021-12-18 16:23:43 +08:00
parent b3bdd7eb64
commit f92398ea31
No known key found for this signature in database
GPG Key ID: DC129173B148701B
9 changed files with 59 additions and 85 deletions

1
.rustfmt.toml Normal file
View File

@ -0,0 +1 @@
imports_granularity = "module"

View File

@ -2,12 +2,11 @@ use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType}; use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
use crate::protocol::Hello::{self, *};
use crate::protocol::{ use crate::protocol::{
self, Ack, Auth, ControlChannelCmd, DataChannelCmd, self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
Hello::{self, *}, DataChannelCmd, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
}; };
use crate::protocol::{read_ack, read_control_cmd, read_data_cmd, read_hello};
use crate::transport::{TcpTransport, TlsTransport, Transport}; use crate::transport::{TcpTransport, TlsTransport, Transport};
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use backoff::ExponentialBackoff; use backoff::ExponentialBackoff;
@ -28,11 +27,11 @@ pub async fn run_client(config: &Config) -> Result<()> {
match config.transport.transport_type { match config.transport.transport_type {
TransportType::Tcp => { TransportType::Tcp => {
let mut client = Client::<TcpTransport>::from(&config).await?; let mut client = Client::<TcpTransport>::from(config).await?;
client.run().await client.run().await
} }
TransportType::Tls => { TransportType::Tls => {
let mut client = Client::<TlsTransport>::from(&config).await?; let mut client = Client::<TlsTransport>::from(config).await?;
client.run().await client.run().await
} }
} }
@ -244,19 +243,14 @@ impl ControlChannelHandle {
tokio::spawn( tokio::spawn(
async move { async move {
loop { while let Err(err) = s
if let Err(err) = s .run()
.run() .await
.await .with_context(|| "Failed to run the control channel")
.with_context(|| "Failed to run the control channel") {
{ let duration = Duration::from_secs(2);
let duration = Duration::from_secs(2); error!("{:?}\n\nRetry in {:?}...", err, duration);
error!("{:?}\n\nRetry in {:?}...", err, duration); time::sleep(duration).await;
time::sleep(duration).await;
} else {
// Shutdown
break;
}
} }
} }
.instrument(Span::current()), .instrument(Span::current()),

View File

@ -1,9 +1,8 @@
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::path::PathBuf; use std::path::Path;
use tokio::fs; use tokio::fs;
use toml;
#[derive(Debug, Serialize, Deserialize, Copy, Clone)] #[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub enum TransportType { pub enum TransportType {
@ -81,8 +80,7 @@ pub struct Config {
impl Config { impl Config {
fn from_str(s: &str) -> Result<Config> { fn from_str(s: &str) -> Result<Config> {
let mut config: Config = let mut config: Config = toml::from_str(s).with_context(|| "Failed to parse the config")?;
toml::from_str(&s).with_context(|| "Failed to parse the config")?;
if let Some(server) = config.server.as_mut() { if let Some(server) = config.server.as_mut() {
Config::validate_server_config(server)?; Config::validate_server_config(server)?;
@ -158,7 +156,7 @@ impl Config {
} }
} }
pub async fn from_file(path: &PathBuf) -> Result<Config> { pub async fn from_file(path: &Path) -> Result<Config> {
let s: String = fs::read_to_string(path) let s: String = fs::read_to_string(path)
.await .await
.with_context(|| format!("Failed to read the config {:?}", path))?; .with_context(|| format!("Failed to read the config {:?}", path))?;

View File

@ -26,7 +26,7 @@ pub async fn run(args: &Cli) -> Result<()> {
// Raise `nofile` limit on linux and mac // Raise `nofile` limit on linux and mac
fdlimit::raise_fd_limit(); fdlimit::raise_fd_limit();
match determine_run_mode(&config, &args) { match determine_run_mode(&config, args) {
RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")), RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")),
RunMode::Client => run_client(&config).await, RunMode::Client => run_client(&config).await,
RunMode::Server => run_server(&config).await, RunMode::Server => run_server(&config).await,
@ -44,20 +44,16 @@ fn determine_run_mode(config: &Config, args: &Cli) -> RunMode {
use RunMode::*; use RunMode::*;
if args.client && args.server { if args.client && args.server {
Undetermine Undetermine
} else if args.client {
Client
} else if args.server {
Server
} else if config.client.is_some() && config.server.is_none() {
Client
} else if config.server.is_some() && config.client.is_none() {
Server
} else { } else {
if args.client { Undetermine
Client
} else if args.server {
Server
} else {
if config.server.is_some() && config.client.is_none() {
Server
} else if config.client.is_some() && config.server.is_none() {
Client
} else {
Undetermine
}
}
} }
} }

View File

@ -1,7 +1,6 @@
use anyhow::Result; use anyhow::Result;
use clap::Parser; use clap::Parser;
use rathole::{run, Cli}; use rathole::{run, Cli};
use tokio;
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {

View File

@ -1,13 +1,13 @@
use crate::config::{Config, ServerConfig, ServerServiceConfig, TransportType}; use crate::config::{Config, ServerConfig, ServerServiceConfig, TransportType};
use crate::multi_map::MultiMap; use crate::multi_map::MultiMap;
use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
use crate::protocol::{ use crate::protocol::{
self, Ack, ControlChannelCmd, DataChannelCmd, Hello, Hello::ControlChannelHello, self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, HASH_WIDTH_IN_BYTES,
Hello::DataChannelHello, HASH_WIDTH_IN_BYTES,
}; };
use crate::protocol::{read_auth, read_hello};
use crate::transport::{TcpTransport, TlsTransport, Transport}; use crate::transport::{TcpTransport, TlsTransport, Transport};
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use backoff::{backoff::Backoff, ExponentialBackoff}; use backoff::backoff::Backoff;
use backoff::ExponentialBackoff;
use rand::RngCore; use rand::RngCore;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
@ -15,8 +15,7 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{self, copy_bidirectional, AsyncWriteExt}; use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc; use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::sync::{oneshot, RwLock};
use tokio::time; use tokio::time;
use tracing::{debug, error, info, info_span, warn, Instrument}; use tracing::{debug, error, info, info_span, warn, Instrument};
@ -190,9 +189,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
concat.append(&mut nonce); concat.append(&mut nonce);
// Read auth // Read auth
let d = match read_auth(&mut conn).await? { let protocol::Auth(d) = read_auth(&mut conn).await?;
protocol::Auth(v) => v,
};
// Validate // Validate
let session_key = protocol::digest(&concat); let session_key = protocol::digest(&concat);
@ -259,13 +256,13 @@ struct ControlChannel<T: Transport> {
} }
struct ControlChannelHandle<T: Transport> { struct ControlChannelHandle<T: Transport> {
shutdown_tx: oneshot::Sender<bool>, _shutdown_tx: oneshot::Sender<bool>,
conn_pool: ConnectionPoolHandle<T>, conn_pool: ConnectionPoolHandle<T>,
} }
impl<T: 'static + Transport> ControlChannelHandle<T> { impl<T: 'static + Transport> ControlChannelHandle<T> {
fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> { fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
let (shutdown_tx, shutdown_rx) = oneshot::channel::<bool>(); let (_shutdown_tx, shutdown_rx) = oneshot::channel::<bool>();
let name = service.name.clone(); let name = service.name.clone();
let conn_pool = ConnectionPoolHandle::new(); let conn_pool = ConnectionPoolHandle::new();
let actor: ControlChannel<T> = ControlChannel { let actor: ControlChannel<T> = ControlChannel {
@ -282,7 +279,7 @@ impl<T: 'static + Transport> ControlChannelHandle<T> {
}); });
ControlChannelHandle { ControlChannelHandle {
shutdown_tx, _shutdown_tx,
conn_pool, conn_pool,
} }
} }
@ -309,7 +306,7 @@ impl<T: Transport> ControlChannel<T> {
let (data_req_tx, mut data_req_rx) = mpsc::unbounded_channel::<u8>(); let (data_req_tx, mut data_req_rx) = mpsc::unbounded_channel::<u8>();
tokio::spawn(async move { tokio::spawn(async move {
let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap(); let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
while let Some(_) = data_req_rx.recv().await { while data_req_rx.recv().await.is_some() {
if self.conn.write_all(&cmd).await.is_err() { if self.conn.write_all(&cmd).await.is_err() {
break; break;
} }
@ -396,18 +393,14 @@ impl<T: 'static + Transport> ConnectionPoolHandle<T> {
impl<T: Transport> ConnectionPool<T> { impl<T: Transport> ConnectionPool<T> {
#[tracing::instrument] #[tracing::instrument]
async fn run(mut self) { async fn run(mut self) {
loop { while let Some(mut visitor) = self.visitor_rx.recv().await {
if let Some(mut visitor) = self.visitor_rx.recv().await { if let Some(mut ch) = self.data_ch_rx.recv().await {
if let Some(mut ch) = self.data_ch_rx.recv().await { tokio::spawn(async move {
tokio::spawn(async move { let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap();
let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap(); if ch.write_all(&cmd).await.is_ok() {
if ch.write_all(&cmd).await.is_ok() { let _ = copy_bidirectional(&mut ch, &mut visitor).await;
let _ = copy_bidirectional(&mut ch, &mut visitor).await; }
} });
});
} else {
break;
}
} else { } else {
break; break;
} }

View File

@ -3,10 +3,8 @@ use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use std::fmt::Debug; use std::fmt::Debug;
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::{ use tokio::io::{AsyncRead, AsyncWrite};
io::{AsyncRead, AsyncWrite}, use tokio::net::ToSocketAddrs;
net::ToSocketAddrs,
};
#[async_trait] #[async_trait]
pub trait Transport: Debug + Send + Sync { pub trait Transport: Debug + Send + Sync {
@ -16,7 +14,7 @@ pub trait Transport: Debug + Send + Sync {
async fn new(config: &TransportConfig) -> Result<Box<Self>>; async fn new(config: &TransportConfig) -> Result<Box<Self>>;
async fn bind<T: ToSocketAddrs + Send + Sync>(&self, addr: T) -> Result<Self::Acceptor>; async fn bind<T: ToSocketAddrs + Send + Sync>(&self, addr: T) -> Result<Self::Acceptor>;
async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)>; async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)>;
async fn connect(&self, addr: &String) -> Result<Self::Stream>; async fn connect(&self, addr: &str) -> Result<Self::Stream>;
} }
mod tcp; mod tcp;

View File

@ -1,4 +1,5 @@
use crate::{config::TransportConfig, helper::set_tcp_keepalive}; use crate::config::TransportConfig;
use crate::helper::set_tcp_keepalive;
use super::Transport; use super::Transport;
use anyhow::Result; use anyhow::Result;
@ -28,7 +29,7 @@ impl Transport for TcpTransport {
Ok((s, addr)) Ok((s, addr))
} }
async fn connect(&self, addr: &String) -> Result<Self::Stream> { async fn connect(&self, addr: &str) -> Result<Self::Stream> {
let s = TcpStream::connect(addr).await?; let s = TcpStream::connect(addr).await?;
if let Err(e) = set_tcp_keepalive(&s) { if let Err(e) = set_tcp_keepalive(&s) {
error!( error!(

View File

@ -1,20 +1,14 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use super::Transport; use super::Transport;
use crate::{ use crate::config::{TlsConfig, TransportConfig};
config::{TlsConfig, TransportConfig}, use crate::helper::set_tcp_keepalive;
helper::set_tcp_keepalive,
};
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use async_trait::async_trait; use async_trait::async_trait;
use tokio::{ use tokio::fs;
fs, use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
net::{TcpListener, TcpStream, ToSocketAddrs}, use tokio_native_tls::native_tls::{self, Certificate, Identity};
}; use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
use tokio_native_tls::{
native_tls::{self, Certificate, Identity},
TlsAcceptor, TlsConnector, TlsStream,
};
use tracing::error; use tracing::error;
#[derive(Debug)] #[derive(Debug)]
@ -39,7 +33,7 @@ impl Transport for TlsTransport {
let connector = match config.trusted_root.as_ref() { let connector = match config.trusted_root.as_ref() {
Some(path) => { Some(path) => {
let s = fs::read_to_string(path).await?; let s = fs::read_to_string(path).await?;
let cert = Certificate::from_pem(&s.as_bytes())?; let cert = Certificate::from_pem(s.as_bytes())?;
let connector = native_tls::TlsConnector::builder() let connector = native_tls::TlsConnector::builder()
.add_root_certificate(cert) .add_root_certificate(cert)
.build()?; .build()?;
@ -74,7 +68,7 @@ impl Transport for TlsTransport {
Ok((conn, addr)) Ok((conn, addr))
} }
async fn connect(&self, addr: &String) -> Result<Self::Stream> { async fn connect(&self, addr: &str) -> Result<Self::Stream> {
let conn = TcpStream::connect(&addr).await?; let conn = TcpStream::connect(&addr).await?;
if let Err(e) = set_tcp_keepalive(&conn) { if let Err(e) = set_tcp_keepalive(&conn) {
error!( error!(