rathole/src/transport/websocket.rs

251 lines
7.9 KiB
Rust

use core::result::Result;
use std::io::{Error, ErrorKind};
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use super::{AddrMaybeCached, SocketOpts, TcpTransport, TlsTransport, Transport};
use crate::config::TransportConfig;
use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
use futures_core::stream::Stream;
use futures_sink::Sink;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_native_tls::TlsStream;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::{accept_async_with_config, client_async_with_config};
use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream};
use tokio_util::io::StreamReader;
use url::Url;
/// Byte stream that sits underneath the websocket layer: either a plain TCP
/// connection or a TLS session layered over TCP.
#[derive(Debug)]
enum TransportStream {
    /// Unencrypted TCP connection.
    Insecure(TcpStream),
    /// TLS session wrapping a TCP connection.
    Secure(TlsStream<TcpStream>),
}

impl TransportStream {
    /// Borrows the raw [`TcpStream`] at the bottom of the stack.
    ///
    /// For the secure variant the TLS wrapper is peeled off one `get_ref()`
    /// layer at a time until the TCP socket is reached.
    fn get_tcpstream(&self) -> &TcpStream {
        match self {
            Self::Insecure(tcp) => tcp,
            Self::Secure(tls) => {
                let native_session = tls.get_ref();
                let io_adapter = native_session.get_ref();
                io_adapter.get_ref()
            }
        }
    }
}
impl AsyncRead for TransportStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
TransportStream::Insecure(s) => Pin::new(s).poll_read(cx, buf),
TransportStream::Secure(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
impl AsyncWrite for TransportStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
TransportStream::Insecure(s) => Pin::new(s).poll_write(cx, buf),
TransportStream::Secure(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
TransportStream::Insecure(s) => Pin::new(s).poll_flush(cx),
TransportStream::Secure(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
TransportStream::Insecure(s) => Pin::new(s).poll_shutdown(cx),
TransportStream::Secure(s) => Pin::new(s).poll_shutdown(cx),
}
}
}
#[derive(Debug)]
struct StreamWrapper {
inner: WebSocketStream<TransportStream>,
}
impl Stream for StreamWrapper {
type Item = Result<Bytes, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match Pin::new(&mut self.get_mut().inner).poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(err))) => {
Poll::Ready(Some(Err(Error::new(ErrorKind::Other, err))))
}
Poll::Ready(Some(Ok(res))) => {
if let Message::Binary(b) = res {
Poll::Ready(Some(Ok(Bytes::from(b))))
} else {
Poll::Ready(Some(Err(Error::new(
ErrorKind::InvalidData,
"unexpected frame",
))))
}
}
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
#[derive(Debug)]
pub struct WebsocketTunnel {
inner: StreamReader<StreamWrapper, Bytes>,
}
impl AsyncRead for WebsocketTunnel {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
}
}
impl AsyncBufRead for WebsocketTunnel {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
Pin::new(&mut self.get_mut().inner).poll_fill_buf(cx)
}
fn consume(self: Pin<&mut Self>, amt: usize) {
Pin::new(&mut self.get_mut().inner).consume(amt)
}
}
impl AsyncWrite for WebsocketTunnel {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let sw = self.get_mut().inner.get_mut();
ready!(Pin::new(&mut sw.inner)
.poll_ready(cx)
.map_err(|err| Error::new(ErrorKind::Other, err)))?;
match Pin::new(&mut sw.inner).start_send(Message::Binary(buf.to_vec())) {
Ok(()) => Poll::Ready(Ok(buf.len())),
Err(e) => Poll::Ready(Err(Error::new(ErrorKind::Other, e))),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Pin::new(&mut self.get_mut().inner.get_mut().inner)
.poll_flush(cx)
.map_err(|err| Error::new(ErrorKind::Other, err))
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Pin::new(&mut self.get_mut().inner.get_mut().inner)
.poll_close(cx)
.map_err(|err| Error::new(ErrorKind::Other, err))
}
}
/// Lower-level transport that the websocket layer runs on top of.
///
/// `WebsocketTransport::new` selects the variant from the `tls` flag of the
/// websocket configuration.
#[derive(Debug)]
enum SubTransport {
/// Websocket traffic is carried over the TLS transport.
Secure(TlsTransport),
/// Websocket traffic is carried over the plain TCP transport.
Insecure(TcpTransport),
}
/// Transport that performs a websocket handshake on top of a TCP or TLS
/// stream and exposes the result as a `WebsocketTunnel`.
#[derive(Debug)]
pub struct WebsocketTransport {
/// Underlying transport used to accept, connect and perform the lower-level
/// handshake before the websocket handshake runs.
sub: SubTransport,
/// Websocket configuration passed to both the server-side and the
/// client-side websocket handshake.
conf: WebSocketConfig,
}
#[async_trait]
impl Transport for WebsocketTransport {
type Acceptor = TcpListener;
type RawStream = TcpStream;
type Stream = WebsocketTunnel;
fn new(config: &TransportConfig) -> anyhow::Result<Self> {
let wsconfig = config
.websocket
.as_ref()
.ok_or_else(|| anyhow!("Missing websocket config"))?;
let conf = WebSocketConfig {
write_buffer_size: 0,
..WebSocketConfig::default()
};
let sub = match wsconfig.tls {
true => SubTransport::Secure(TlsTransport::new(config)?),
false => SubTransport::Insecure(TcpTransport::new(config)?),
};
Ok(WebsocketTransport { sub, conf })
}
fn hint(conn: &Self::Stream, opt: SocketOpts) {
opt.apply(conn.inner.get_ref().inner.get_ref().get_tcpstream())
}
async fn bind<A: ToSocketAddrs + Send + Sync>(
&self,
addr: A,
) -> anyhow::Result<Self::Acceptor> {
TcpListener::bind(addr).await.map_err(Into::into)
}
async fn accept(&self, a: &Self::Acceptor) -> anyhow::Result<(Self::RawStream, SocketAddr)> {
let (s, addr) = match &self.sub {
SubTransport::Insecure(t) => t.accept(a).await?,
SubTransport::Secure(t) => t.accept(a).await?,
};
Ok((s, addr))
}
async fn handshake(&self, conn: Self::RawStream) -> anyhow::Result<Self::Stream> {
let tsream = match &self.sub {
SubTransport::Insecure(t) => TransportStream::Insecure(t.handshake(conn).await?),
SubTransport::Secure(t) => TransportStream::Secure(t.handshake(conn).await?),
};
let wsstream = accept_async_with_config(tsream, Some(self.conf)).await?;
let tun = WebsocketTunnel {
inner: StreamReader::new(StreamWrapper { inner: wsstream }),
};
Ok(tun)
}
async fn connect(&self, addr: &AddrMaybeCached) -> anyhow::Result<Self::Stream> {
let u = format!("ws://{}", &addr.addr.as_str());
let url = Url::parse(&u).unwrap();
let tstream = match &self.sub {
SubTransport::Insecure(t) => TransportStream::Insecure(t.connect(addr).await?),
SubTransport::Secure(t) => TransportStream::Secure(t.connect(addr).await?),
};
let (wsstream, _) = client_async_with_config(url, tstream, Some(self.conf))
.await
.expect("failed to connect");
let tun = WebsocketTunnel {
inner: StreamReader::new(StreamWrapper { inner: wsstream }),
};
Ok(tun)
}
}