// rathole/src/protocol.rs
//
// 138 lines
// 3.9 KiB
// Rust

/// Width, in bytes, of every digest used by this protocol (one SHA-256 output).
pub const HASH_WIDTH_IN_BYTES: usize = 256 / 8;
use anyhow::{Context, Result};
use bincode;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
/// Version tag carried in both [`Hello`] variants.
type ProtocolVersion = u8;
/// Protocol version 0, the only version defined in this file.
const PROTO_V0: ProtocolVersion = 0;
/// Protocol version spoken by this build.
///
/// NOTE: the identifier is misspelled ("VRESION"). It is kept as-is because other
/// modules may refer to it by this exact name.
pub const CURRENT_PROTO_VRESION: ProtocolVersion = PROTO_V0;
/// Fixed-width digest, the output type of [`digest`].
pub type Digest = [u8; HASH_WIDTH_IN_BYTES];
/// Greeting packet, decoded by [`read_hello`]. It says which kind of channel the
/// connection is opening.
///
/// NOTE: serde/bincode encode the variant by its index, so reordering the variants
/// changes the wire format.
#[derive(Debug, Serialize, Deserialize)]
pub enum Hello {
    /// Opens a control channel. The digest is sha256sum(service name) or a nonce.
    ControlChannelHello(ProtocolVersion, Digest),
    /// Opens a data channel. The digest is the token provided by `CreateDataChannel`.
    DataChannelHello(ProtocolVersion, Digest),
}
/// Authentication packet carrying a single digest, decoded by [`read_auth`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Auth(pub Digest);
/// Acknowledgement packet, decoded by [`read_ack`].
///
/// NOTE: serde/bincode encode the variant by its index, so reordering the variants
/// changes the wire format.
#[derive(Debug, Serialize, Deserialize)]
pub enum Ack {
    /// Success; displayed as "Ok".
    Ok,
    /// The named service is unknown; displayed as "Service not exist".
    ServiceNotExist,
    /// Authentication was rejected; displayed as "Incorrect token".
    AuthFailed,
}
impl std::fmt::Display for Ack {
    /// Writes a short human-readable description of the acknowledgement.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Ack::Ok => "Ok",
            Ack::ServiceNotExist => "Service not exist",
            Ack::AuthFailed => "Incorrect token",
        };
        // Emits the text verbatim; width/fill/precision flags are not applied.
        f.write_str(description)
    }
}
/// Command packet sent on a control channel, decoded by [`read_control_cmd`].
#[derive(Debug, Serialize, Deserialize)]
pub enum ControlChannelCmd {
    /// Requests creation of a data channel.
    CreateDataChannel,
}
/// Command packet sent on a data channel, decoded by [`read_data_cmd`].
#[derive(Debug, Serialize, Deserialize)]
pub enum DataChannelCmd {
    /// Start forwarding. NOTE(review): the exact semantics are defined by the callers.
    StartForward,
}
/// Computes the SHA-256 digest of `data`.
///
/// # Panics
///
/// Only if the SHA-256 output length differs from `HASH_WIDTH_IN_BYTES`. That cannot
/// happen while the constant is 32, because SHA-256 always yields 32 bytes.
pub fn digest(data: &[u8]) -> Digest {
    let d = ring::digest::digest(&ring::digest::SHA256, data);
    // The conversion into a fixed-size array is infallible by construction.
    // `expect` records that invariant instead of a bare `unwrap`.
    d.as_ref()
        .try_into()
        .expect("SHA-256 output length must equal HASH_WIDTH_IN_BYTES")
}
struct PacketLength {
hello: usize,
ack: usize,
auth: usize,
c_cmd: usize,
d_cmd: usize,
}
impl PacketLength {
pub fn new() -> PacketLength {
let username = "default";
let d = digest(username.as_bytes());
let hello = bincode::serialized_size(&Hello::ControlChannelHello(CURRENT_PROTO_VRESION, d))
.unwrap() as usize;
let c_cmd =
bincode::serialized_size(&ControlChannelCmd::CreateDataChannel).unwrap() as usize;
let d_cmd = bincode::serialized_size(&DataChannelCmd::StartForward).unwrap() as usize;
let ack = Ack::Ok;
let ack = bincode::serialized_size(&ack).unwrap() as usize;
let auth = bincode::serialized_size(&Auth(d)).unwrap() as usize;
PacketLength {
hello,
ack,
auth,
c_cmd,
d_cmd,
}
}
}
lazy_static! {
static ref PACKET_LEN: PacketLength = PacketLength::new();
}
/// Reads exactly one [`Hello`] packet from `conn` and decodes it.
///
/// # Errors
///
/// Returns an error if the stream fails or ends before a full packet arrives, or if
/// the bytes do not decode as a `Hello`.
pub async fn read_hello(conn: &mut TcpStream) -> Result<Hello> {
    let mut packet = vec![0u8; PACKET_LEN.hello];
    conn.read_exact(&mut packet)
        .await
        .context("Failed to read hello")?;
    bincode::deserialize(&packet).context("Failed to deserialize hello")
}
/// Reads exactly one [`Auth`] packet from `conn` and decodes it.
///
/// # Errors
///
/// Returns an error if the stream fails or ends before a full packet arrives, or if
/// the bytes do not decode as an `Auth`.
pub async fn read_auth(conn: &mut TcpStream) -> Result<Auth> {
    let mut packet = vec![0u8; PACKET_LEN.auth];
    conn.read_exact(&mut packet)
        .await
        .context("Failed to read auth")?;
    bincode::deserialize(&packet).context("Failed to deserialize auth")
}
/// Reads exactly one [`Ack`] packet from `conn` and decodes it.
///
/// # Errors
///
/// Returns an error if the stream fails or ends before a full packet arrives, or if
/// the bytes do not decode as an `Ack`.
pub async fn read_ack(conn: &mut TcpStream) -> Result<Ack> {
    let mut packet = vec![0u8; PACKET_LEN.ack];
    conn.read_exact(&mut packet)
        .await
        .context("Failed to read ack")?;
    bincode::deserialize(&packet).context("Failed to deserialize ack")
}
/// Reads exactly one [`ControlChannelCmd`] packet from `conn` and decodes it.
///
/// # Errors
///
/// Returns an error if the stream fails or ends before a full packet arrives, or if
/// the bytes do not decode as a `ControlChannelCmd`.
pub async fn read_control_cmd(conn: &mut TcpStream) -> Result<ControlChannelCmd> {
    let mut packet = vec![0u8; PACKET_LEN.c_cmd];
    conn.read_exact(&mut packet)
        .await
        .context("Failed to read control cmd")?;
    bincode::deserialize(&packet).context("Failed to deserialize control cmd")
}
/// Reads exactly one [`DataChannelCmd`] packet from `conn` and decodes it.
///
/// # Errors
///
/// Returns an error if the stream fails or ends before a full packet arrives, or if
/// the bytes do not decode as a `DataChannelCmd`.
pub async fn read_data_cmd(conn: &mut TcpStream) -> Result<DataChannelCmd> {
    let mut packet = vec![0u8; PACKET_LEN.d_cmd];
    conn.read_exact(&mut packet)
        .await
        .context("Failed to read data cmd")?;
    bincode::deserialize(&packet).context("Failed to deserialize data cmd")
}