From a3e59a38eff2a4942fb42938a8243fe73aa62650 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 10 Jul 2023 12:43:48 -0500 Subject: [PATCH] Use registered io on Windows when possible (#905) --- Makefile | 4 +- go.mod | 1 + go.sum | 2 + interface.go | 7 + udp/conn.go | 4 + udp/udp_android.go | 5 + udp/udp_darwin.go | 5 + udp/udp_freebsd.go | 5 + udp/udp_generic.go | 8 +- udp/udp_linux.go | 9 +- udp/udp_rio_windows.go | 403 +++++++++++++++++++++++++++++++++++++++++ udp/udp_tester.go | 6 + udp/udp_windows.go | 21 ++- 13 files changed, 472 insertions(+), 8 deletions(-) create mode 100644 udp/udp_rio_windows.go diff --git a/Makefile b/Makefile index fecd889..7eaa07f 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,8 @@ ifeq ($(OS),Windows_NT) GOISMIN := $(shell IF "$(GOVERSION)" GEQ "$(GOMINVERSION)" ECHO 1) NEBULA_CMD_SUFFIX = .exe NULL_FILE = nul + # RIO on windows does pointer stuff that makes go vet angry + VET_FLAGS = -unsafeptr=false else GOVERSION := $(shell go version | awk '{print substr($$3, 3)}') GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)") @@ -143,7 +145,7 @@ build/nebula-%.zip: build/%/nebula.exe build/%/nebula-cert.exe cd build/$* && zip ../nebula-$*.zip nebula.exe nebula-cert.exe vet: - go vet -v ./... + go vet $(VET_FLAGS) -v ./... test: go test -v ./... diff --git a/go.mod b/go.mod index 52c2e92..9e92752 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( golang.org/x/sys v0.8.0 golang.org/x/term v0.8.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 + golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/protobuf v1.30.0 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index 452a1d2..ce47641 100644 --- a/go.sum +++ b/go.sum @@ -219,6 +219,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= +golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= +golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= diff --git a/interface.go b/interface.go index 82ab0f0..b5d43d2 100644 --- a/interface.go +++ b/interface.go @@ -413,6 +413,13 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { func (f *Interface) Close() error { f.closed.Store(true) + for _, u := range f.writers { + err := u.Close() + if err != nil { + f.l.WithError(err).Error("Error while closing udp socket") + } + } + // Release the tun device return f.inside.Close() } diff --git a/udp/conn.go b/udp/conn.go index 33520db..a2c24a1 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -26,6 +26,7 @@ type Conn interface { ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) WriteTo(b []byte, addr *Addr) error ReloadConfig(c *config.C) + Close() error } type NoopConn struct{} @@ -45,3 +46,6 @@ func (NoopConn) WriteTo(_ []byte, _ *Addr) error { func (NoopConn) ReloadConfig(_ *config.C) { return } +func (NoopConn) Close() error { + return nil +} diff --git a/udp/udp_android.go b/udp/udp_android.go index 08cde96..8d69074 100644 --- a/udp/udp_android.go +++ b/udp/udp_android.go @@ -8,9 +8,14 @@ import ( "net" "syscall" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 260ce44..afbf240 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -10,9 +10,14 @@ import ( "net" "syscall" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { diff --git a/udp/udp_freebsd.go b/udp/udp_freebsd.go index 920f91b..3c14fac 100644 --- a/udp/udp_freebsd.go +++ b/udp/udp_freebsd.go @@ -10,9 +10,14 @@ import ( "net" "syscall" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 490cc61..1dd6d1d 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -23,7 +23,9 @@ type GenericConn struct { l *logrus.Logger } -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +var _ Conn = &GenericConn{} + +func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -80,8 +82,8 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f // Just read one packet at a time n, rua, err := u.ReadFromUDP(buffer) if err != nil { - u.l.WithError(err).Error("Failed to read packets") - continue + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return } udpAddr.IP = rua.IP diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 60defaa..ca050bb 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -137,8 +137,8 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew for { n, err := read(msgs) if err != nil { - u.l.WithError(err).Error("Failed to read packets") - continue + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return } //metric.Update(int64(n)) @@ -262,6 +262,11 @@ func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error { return nil } +func (u *StdConn) Close() error { + //TODO: this will not interrupt the read loop + return syscall.Close(u.sysFd) +} + func NewUDPStatsEmitter(udpConns []Conn) func() { // Check if our kernel supports SO_MEMINFO before registering the gauges var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go new file mode 100644 index 0000000..31c1a55 --- /dev/null +++ b/udp/udp_rio_windows.go @@ -0,0 +1,403 @@ +//go:build !e2e_testing +// +build !e2e_testing + +// Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go + +package udp + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "syscall" + "unsafe" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/conn/winrio" +) + +// Assert we meet the standard conn interface +var _ Conn = &RIOConn{} + +//go:linkname procyield runtime.procyield +func procyield(cycles uint32) + +const ( + packetsPerRing = 1024 + bytesPerPacket = 2048 - 32 + receiveSpins = 15 +) + +type ringPacket struct { + addr windows.RawSockaddrInet6 + data [bytesPerPacket]byte +} + +type ringBuffer struct { + packets uintptr + head, tail uint32 + id winrio.BufferId + iocp windows.Handle + isFull bool + cq winrio.Cq + mu sync.Mutex + overlapped windows.Overlapped +} + +type RIOConn struct { + isOpen atomic.Bool + l *logrus.Logger + sock windows.Handle + rx, tx ringBuffer + rq winrio.Rq + results [packetsPerRing]winrio.Result +} + +func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) { + if !winrio.Initialize() { + return nil, errors.New("could not initialize winrio") + } + + u := &RIOConn{l: l} + + addr := [16]byte{} + copy(addr[:], ip.To16()) + err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port}) + if err != nil { + return nil, fmt.Errorf("bind: %w", err) + } + + for i := 0; i < packetsPerRing; i++ { + err = u.insertReceiveRequest() + if err != nil { + return nil, fmt.Errorf("init rx ring: %w", err) + } + } + + u.isOpen.Store(true) + return u, nil +} + +func (u *RIOConn) bind(sa windows.Sockaddr) error { + var err error + u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP) + if err != nil { + return err + } + + // Enable v4 for this socket + syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + + err = u.rx.Open() + if err != nil { + return err + } + + err = u.tx.Open() + if err != nil { + return err + } + + u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0) + if err != nil { + return err + } + + err = windows.Bind(u.sock, sa) + if err != nil { + return err + } + + return nil +} + +func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { + plaintext := make([]byte, MTU) + buffer := make([]byte, MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + udpAddr := &Addr{IP: make([]byte, 16)} + nb := make([]byte, 12, 12) + + for { + // Just read one packet at a time + n, rua, err := u.receive(buffer) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + udpAddr.IP = rua.Addr[:] + p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port)) + p[0] = byte(rua.Port >> 8) + p[1] = byte(rua.Port) + r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + } +} + +func (u *RIOConn) insertReceiveRequest() error { + packet := u.rx.Push() + dataBuffer := &winrio.Buffer{ + Id: u.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets), + Length: uint32(len(packet.data)), + } + addressBuffer := &winrio.Buffer{ + Id: u.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + + return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) +} + +func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) { + if !u.isOpen.Load() { + return 0, windows.RawSockaddrInet6{}, net.ErrClosed + } + + u.rx.mu.Lock() + defer u.rx.mu.Unlock() + + var err error + var count uint32 + var results [1]winrio.Result + +retry: + count = 0 + for tries := 0; count == 0 && tries < receiveSpins; tries++ { + if tries > 0 { + if !u.isOpen.Load() { + return 0, windows.RawSockaddrInet6{}, net.ErrClosed + } + procyield(1) + } + + count = winrio.DequeueCompletion(u.rx.cq, results[:]) + } + + if count == 0 { + err = winrio.Notify(u.rx.cq) + if err != nil { + return 0, windows.RawSockaddrInet6{}, err + } + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return 0, windows.RawSockaddrInet6{}, err + } + + if !u.isOpen.Load() { + return 0, windows.RawSockaddrInet6{}, net.ErrClosed + } + + count = winrio.DequeueCompletion(u.rx.cq, results[:]) + if count == 0 { + return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress + + } + } + + u.rx.Return(1) + err = u.insertReceiveRequest() + if err != nil { + return 0, windows.RawSockaddrInet6{}, err + } + + // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us + // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to + // attacker bandwidth, just like the rest of the receive path. + if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { + goto retry + } + + if results[0].Status != 0 { + return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status) + } + + packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) + ep := packet.addr + n := copy(buf, packet.data[:results[0].BytesTransferred]) + return n, ep, nil +} + +func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { + if !u.isOpen.Load() { + return net.ErrClosed + } + + if len(buf) > bytesPerPacket { + return io.ErrShortBuffer + } + + u.tx.mu.Lock() + defer u.tx.mu.Unlock() + + count := winrio.DequeueCompletion(u.tx.cq, u.results[:]) + if count == 0 && u.tx.isFull { + err := winrio.Notify(u.tx.cq) + if err != nil { + return err + } + + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return err + } + + if !u.isOpen.Load() { + return net.ErrClosed + } + + count = winrio.DequeueCompletion(u.tx.cq, u.results[:]) + if count == 0 { + return io.ErrNoProgress + } + } + + if count > 0 { + u.tx.Return(count) + } + + packet := u.tx.Push() + packet.addr.Family = windows.AF_INET6 + p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port)) + p[0] = byte(addr.Port >> 8) + p[1] = byte(addr.Port) + copy(packet.addr.Addr[:], addr.IP.To16()) + copy(packet.data[:], buf) + + dataBuffer := &winrio.Buffer{ + Id: u.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets), + Length: uint32(len(buf)), + } + + addressBuffer := &winrio.Buffer{ + Id: u.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + + return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) +} + +func (u *RIOConn) LocalAddr() (*Addr, error) { + sa, err := windows.Getsockname(u.sock) + if err != nil { + return nil, err + } + + v6 := sa.(*windows.SockaddrInet6) + return &Addr{ + IP: v6.Addr[:], + Port: uint16(v6.Port), + }, nil +} + +func (u *RIOConn) Rebind() error { + return nil +} + +func (u *RIOConn) ReloadConfig(*config.C) {} + +func (u *RIOConn) Close() error { + if !u.isOpen.CompareAndSwap(true, false) { + return nil + } + + windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil) + + u.rx.CloseAndZero() + u.tx.CloseAndZero() + if u.sock != 0 { + windows.CloseHandle(u.sock) + } + return nil +} + +func (ring *ringBuffer) Push() *ringPacket { + for ring.isFull { + panic("ring is full") + } + ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) + ring.tail += 1 + if ring.tail%packetsPerRing == ring.head%packetsPerRing { + ring.isFull = true + } + return ret +} + +func (ring *ringBuffer) Return(count uint32) { + if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull { + return + } + ring.head += count + ring.isFull = false +} + +func (ring *ringBuffer) CloseAndZero() { + if ring.cq != 0 { + winrio.CloseCompletionQueue(ring.cq) + ring.cq = 0 + } + + if ring.iocp != 0 { + windows.CloseHandle(ring.iocp) + ring.iocp = 0 + } + + if ring.id != 0 { + winrio.DeregisterBuffer(ring.id) + ring.id = 0 + } + + if ring.packets != 0 { + windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) + ring.packets = 0 + } + + ring.head = 0 + ring.tail = 0 + ring.isFull = false +} + +func (ring *ringBuffer) Open() error { + var err error + packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing + ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + return err + } + + ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) + if err != nil { + return err + } + + ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return err + } + + ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) + if err != nil { + return err + } + + return nil +} diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 4396c48..f03a69c 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -140,3 +140,9 @@ func (u *TesterConn) LocalAddr() (*Addr, error) { func (u *TesterConn) Rebind() error { return nil } + +func (u *TesterConn) Close() error { + close(u.RxPackets) + close(u.TxPackets) + return nil +} diff --git a/udp/udp_windows.go b/udp/udp_windows.go index 1456ede..ebcace6 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -3,14 +3,31 @@ package udp -// Windows support is primarily implemented in udp_generic, besides NewListenConfig - import ( "fmt" "net" "syscall" + + "github.com/sirupsen/logrus" ) +func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + if multi { + //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level + // The udp stack would need to be reworked to hide away the implementation differences between + // Windows and Linux + return nil, fmt.Errorf("multiple udp listeners not supported on windows") + } + + rc, err := NewRIOListener(l, ip, port) + if err == nil { + return rc, nil + } + + l.WithError(err).Error("Falling back to standard udp sockets") + return NewGenericListener(l, ip, port, multi, batch) +} + func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error {