mirror of https://github.com/slackhq/nebula.git
Use registered io on Windows when possible (#905)
This commit is contained in:
parent
8ba5d64dbc
commit
a3e59a38ef
4
Makefile
4
Makefile
|
@ -12,6 +12,8 @@ ifeq ($(OS),Windows_NT)
|
||||||
GOISMIN := $(shell IF "$(GOVERSION)" GEQ "$(GOMINVERSION)" ECHO 1)
|
GOISMIN := $(shell IF "$(GOVERSION)" GEQ "$(GOMINVERSION)" ECHO 1)
|
||||||
NEBULA_CMD_SUFFIX = .exe
|
NEBULA_CMD_SUFFIX = .exe
|
||||||
NULL_FILE = nul
|
NULL_FILE = nul
|
||||||
|
# RIO on windows does pointer stuff that makes go vet angry
|
||||||
|
VET_FLAGS = -unsafeptr=false
|
||||||
else
|
else
|
||||||
GOVERSION := $(shell go version | awk '{print substr($$3, 3)}')
|
GOVERSION := $(shell go version | awk '{print substr($$3, 3)}')
|
||||||
GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)")
|
GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)")
|
||||||
|
@ -143,7 +145,7 @@ build/nebula-%.zip: build/%/nebula.exe build/%/nebula-cert.exe
|
||||||
cd build/$* && zip ../nebula-$*.zip nebula.exe nebula-cert.exe
|
cd build/$* && zip ../nebula-$*.zip nebula.exe nebula-cert.exe
|
||||||
|
|
||||||
vet:
|
vet:
|
||||||
go vet -v ./...
|
go vet $(VET_FLAGS) -v ./...
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go test -v ./...
|
go test -v ./...
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -26,6 +26,7 @@ require (
|
||||||
golang.org/x/sys v0.8.0
|
golang.org/x/sys v0.8.0
|
||||||
golang.org/x/term v0.8.0
|
golang.org/x/term v0.8.0
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/protobuf v1.30.0
|
google.golang.org/protobuf v1.30.0
|
||||||
gopkg.in/yaml.v2 v2.4.0
|
gopkg.in/yaml.v2 v2.4.0
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -219,6 +219,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
|
||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
|
|
|
@ -413,6 +413,13 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||||
func (f *Interface) Close() error {
|
func (f *Interface) Close() error {
|
||||||
f.closed.Store(true)
|
f.closed.Store(true)
|
||||||
|
|
||||||
|
for _, u := range f.writers {
|
||||||
|
err := u.Close()
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).Error("Error while closing udp socket")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Release the tun device
|
// Release the tun device
|
||||||
return f.inside.Close()
|
return f.inside.Close()
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ type Conn interface {
|
||||||
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
|
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
|
||||||
WriteTo(b []byte, addr *Addr) error
|
WriteTo(b []byte, addr *Addr) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type NoopConn struct{}
|
type NoopConn struct{}
|
||||||
|
@ -45,3 +46,6 @@ func (NoopConn) WriteTo(_ []byte, _ *Addr) error {
|
||||||
func (NoopConn) ReloadConfig(_ *config.C) {
|
func (NoopConn) ReloadConfig(_ *config.C) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
func (NoopConn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -8,9 +8,14 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||||
|
return NewGenericListener(l, ip, port, multi, batch)
|
||||||
|
}
|
||||||
|
|
||||||
func NewListenConfig(multi bool) net.ListenConfig {
|
func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
return net.ListenConfig{
|
return net.ListenConfig{
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
|
|
@ -10,9 +10,14 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||||
|
return NewGenericListener(l, ip, port, multi, batch)
|
||||||
|
}
|
||||||
|
|
||||||
func NewListenConfig(multi bool) net.ListenConfig {
|
func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
return net.ListenConfig{
|
return net.ListenConfig{
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
|
|
@ -10,9 +10,14 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||||
|
return NewGenericListener(l, ip, port, multi, batch)
|
||||||
|
}
|
||||||
|
|
||||||
func NewListenConfig(multi bool) net.ListenConfig {
|
func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
return net.ListenConfig{
|
return net.ListenConfig{
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
|
|
@ -23,7 +23,9 @@ type GenericConn struct {
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
var _ Conn = &GenericConn{}
|
||||||
|
|
||||||
|
func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||||
lc := NewListenConfig(multi)
|
lc := NewListenConfig(multi)
|
||||||
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
|
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -80,8 +82,8 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
|
||||||
// Just read one packet at a time
|
// Just read one packet at a time
|
||||||
n, rua, err := u.ReadFromUDP(buffer)
|
n, rua, err := u.ReadFromUDP(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.l.WithError(err).Error("Failed to read packets")
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
udpAddr.IP = rua.IP
|
udpAddr.IP = rua.IP
|
||||||
|
|
|
@ -137,8 +137,8 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
|
||||||
for {
|
for {
|
||||||
n, err := read(msgs)
|
n, err := read(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.l.WithError(err).Error("Failed to read packets")
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//metric.Update(int64(n))
|
//metric.Update(int64(n))
|
||||||
|
@ -262,6 +262,11 @@ func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) Close() error {
|
||||||
|
//TODO: this will not interrupt the read loop
|
||||||
|
return syscall.Close(u.sysFd)
|
||||||
|
}
|
||||||
|
|
||||||
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
||||||
// Check if our kernel supports SO_MEMINFO before registering the gauges
|
// Check if our kernel supports SO_MEMINFO before registering the gauges
|
||||||
var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
|
var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
|
||||||
|
|
|
@ -0,0 +1,403 @@
|
||||||
|
//go:build !e2e_testing
|
||||||
|
// +build !e2e_testing
|
||||||
|
|
||||||
|
// Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go
|
||||||
|
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
"golang.zx2c4.com/wireguard/conn/winrio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Assert we meet the standard conn interface
|
||||||
|
var _ Conn = &RIOConn{}
|
||||||
|
|
||||||
|
//go:linkname procyield runtime.procyield
|
||||||
|
func procyield(cycles uint32)
|
||||||
|
|
||||||
|
const (
|
||||||
|
packetsPerRing = 1024
|
||||||
|
bytesPerPacket = 2048 - 32
|
||||||
|
receiveSpins = 15
|
||||||
|
)
|
||||||
|
|
||||||
|
type ringPacket struct {
|
||||||
|
addr windows.RawSockaddrInet6
|
||||||
|
data [bytesPerPacket]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ringBuffer struct {
|
||||||
|
packets uintptr
|
||||||
|
head, tail uint32
|
||||||
|
id winrio.BufferId
|
||||||
|
iocp windows.Handle
|
||||||
|
isFull bool
|
||||||
|
cq winrio.Cq
|
||||||
|
mu sync.Mutex
|
||||||
|
overlapped windows.Overlapped
|
||||||
|
}
|
||||||
|
|
||||||
|
type RIOConn struct {
|
||||||
|
isOpen atomic.Bool
|
||||||
|
l *logrus.Logger
|
||||||
|
sock windows.Handle
|
||||||
|
rx, tx ringBuffer
|
||||||
|
rq winrio.Rq
|
||||||
|
results [packetsPerRing]winrio.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
|
||||||
|
if !winrio.Initialize() {
|
||||||
|
return nil, errors.New("could not initialize winrio")
|
||||||
|
}
|
||||||
|
|
||||||
|
u := &RIOConn{l: l}
|
||||||
|
|
||||||
|
addr := [16]byte{}
|
||||||
|
copy(addr[:], ip.To16())
|
||||||
|
err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("bind: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < packetsPerRing; i++ {
|
||||||
|
err = u.insertReceiveRequest()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init rx ring: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
u.isOpen.Store(true)
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RIOConn) bind(sa windows.Sockaddr) error {
|
||||||
|
var err error
|
||||||
|
u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable v4 for this socket
|
||||||
|
syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
|
||||||
|
|
||||||
|
err = u.rx.Open()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = u.tx.Open()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = windows.Bind(u.sock, sa)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
|
||||||
|
plaintext := make([]byte, MTU)
|
||||||
|
buffer := make([]byte, MTU)
|
||||||
|
h := &header.H{}
|
||||||
|
fwPacket := &firewall.Packet{}
|
||||||
|
udpAddr := &Addr{IP: make([]byte, 16)}
|
||||||
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Just read one packet at a time
|
||||||
|
n, rua, err := u.receive(buffer)
|
||||||
|
if err != nil {
|
||||||
|
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
udpAddr.IP = rua.Addr[:]
|
||||||
|
p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
|
||||||
|
p[0] = byte(rua.Port >> 8)
|
||||||
|
p[1] = byte(rua.Port)
|
||||||
|
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RIOConn) insertReceiveRequest() error {
|
||||||
|
packet := u.rx.Push()
|
||||||
|
dataBuffer := &winrio.Buffer{
|
||||||
|
Id: u.rx.id,
|
||||||
|
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets),
|
||||||
|
Length: uint32(len(packet.data)),
|
||||||
|
}
|
||||||
|
addressBuffer := &winrio.Buffer{
|
||||||
|
Id: u.rx.id,
|
||||||
|
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets),
|
||||||
|
Length: uint32(unsafe.Sizeof(packet.addr)),
|
||||||
|
}
|
||||||
|
|
||||||
|
return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) {
|
||||||
|
if !u.isOpen.Load() {
|
||||||
|
return 0, windows.RawSockaddrInet6{}, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
u.rx.mu.Lock()
|
||||||
|
defer u.rx.mu.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var count uint32
|
||||||
|
var results [1]winrio.Result
|
||||||
|
|
||||||
|
retry:
|
||||||
|
count = 0
|
||||||
|
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
|
||||||
|
if tries > 0 {
|
||||||
|
if !u.isOpen.Load() {
|
||||||
|
return 0, windows.RawSockaddrInet6{}, net.ErrClosed
|
||||||
|
}
|
||||||
|
procyield(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
count = winrio.DequeueCompletion(u.rx.cq, results[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
if count == 0 {
|
||||||
|
err = winrio.Notify(u.rx.cq)
|
||||||
|
if err != nil {
|
||||||
|
return 0, windows.RawSockaddrInet6{}, err
|
||||||
|
}
|
||||||
|
var bytes uint32
|
||||||
|
var key uintptr
|
||||||
|
var overlapped *windows.Overlapped
|
||||||
|
err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
|
||||||
|
if err != nil {
|
||||||
|
return 0, windows.RawSockaddrInet6{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !u.isOpen.Load() {
|
||||||
|
return 0, windows.RawSockaddrInet6{}, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
count = winrio.DequeueCompletion(u.rx.cq, results[:])
|
||||||
|
if count == 0 {
|
||||||
|
return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
u.rx.Return(1)
|
||||||
|
err = u.insertReceiveRequest()
|
||||||
|
if err != nil {
|
||||||
|
return 0, windows.RawSockaddrInet6{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
|
||||||
|
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
|
||||||
|
// attacker bandwidth, just like the rest of the receive path.
|
||||||
|
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
|
||||||
|
if results[0].Status != 0 {
|
||||||
|
return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
|
||||||
|
ep := packet.addr
|
||||||
|
n := copy(buf, packet.data[:results[0].BytesTransferred])
|
||||||
|
return n, ep, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
|
||||||
|
if !u.isOpen.Load() {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(buf) > bytesPerPacket {
|
||||||
|
return io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
u.tx.mu.Lock()
|
||||||
|
defer u.tx.mu.Unlock()
|
||||||
|
|
||||||
|
count := winrio.DequeueCompletion(u.tx.cq, u.results[:])
|
||||||
|
if count == 0 && u.tx.isFull {
|
||||||
|
err := winrio.Notify(u.tx.cq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var bytes uint32
|
||||||
|
var key uintptr
|
||||||
|
var overlapped *windows.Overlapped
|
||||||
|
err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !u.isOpen.Load() {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
count = winrio.DequeueCompletion(u.tx.cq, u.results[:])
|
||||||
|
if count == 0 {
|
||||||
|
return io.ErrNoProgress
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if count > 0 {
|
||||||
|
u.tx.Return(count)
|
||||||
|
}
|
||||||
|
|
||||||
|
packet := u.tx.Push()
|
||||||
|
packet.addr.Family = windows.AF_INET6
|
||||||
|
p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
|
||||||
|
p[0] = byte(addr.Port >> 8)
|
||||||
|
p[1] = byte(addr.Port)
|
||||||
|
copy(packet.addr.Addr[:], addr.IP.To16())
|
||||||
|
copy(packet.data[:], buf)
|
||||||
|
|
||||||
|
dataBuffer := &winrio.Buffer{
|
||||||
|
Id: u.tx.id,
|
||||||
|
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets),
|
||||||
|
Length: uint32(len(buf)),
|
||||||
|
}
|
||||||
|
|
||||||
|
addressBuffer := &winrio.Buffer{
|
||||||
|
Id: u.tx.id,
|
||||||
|
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets),
|
||||||
|
Length: uint32(unsafe.Sizeof(packet.addr)),
|
||||||
|
}
|
||||||
|
|
||||||
|
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RIOConn) LocalAddr() (*Addr, error) {
|
||||||
|
sa, err := windows.Getsockname(u.sock)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
v6 := sa.(*windows.SockaddrInet6)
|
||||||
|
return &Addr{
|
||||||
|
IP: v6.Addr[:],
|
||||||
|
Port: uint16(v6.Port),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RIOConn) Rebind() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *RIOConn) ReloadConfig(*config.C) {}
|
||||||
|
|
||||||
|
func (u *RIOConn) Close() error {
|
||||||
|
if !u.isOpen.CompareAndSwap(true, false) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil)
|
||||||
|
windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil)
|
||||||
|
|
||||||
|
u.rx.CloseAndZero()
|
||||||
|
u.tx.CloseAndZero()
|
||||||
|
if u.sock != 0 {
|
||||||
|
windows.CloseHandle(u.sock)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ring *ringBuffer) Push() *ringPacket {
|
||||||
|
for ring.isFull {
|
||||||
|
panic("ring is full")
|
||||||
|
}
|
||||||
|
ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
|
||||||
|
ring.tail += 1
|
||||||
|
if ring.tail%packetsPerRing == ring.head%packetsPerRing {
|
||||||
|
ring.isFull = true
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ring *ringBuffer) Return(count uint32) {
|
||||||
|
if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ring.head += count
|
||||||
|
ring.isFull = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ring *ringBuffer) CloseAndZero() {
|
||||||
|
if ring.cq != 0 {
|
||||||
|
winrio.CloseCompletionQueue(ring.cq)
|
||||||
|
ring.cq = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if ring.iocp != 0 {
|
||||||
|
windows.CloseHandle(ring.iocp)
|
||||||
|
ring.iocp = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if ring.id != 0 {
|
||||||
|
winrio.DeregisterBuffer(ring.id)
|
||||||
|
ring.id = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if ring.packets != 0 {
|
||||||
|
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
|
||||||
|
ring.packets = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
ring.head = 0
|
||||||
|
ring.tail = 0
|
||||||
|
ring.isFull = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ring *ringBuffer) Open() error {
|
||||||
|
var err error
|
||||||
|
packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
|
||||||
|
ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -140,3 +140,9 @@ func (u *TesterConn) LocalAddr() (*Addr, error) {
|
||||||
func (u *TesterConn) Rebind() error {
|
func (u *TesterConn) Rebind() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *TesterConn) Close() error {
|
||||||
|
close(u.RxPackets)
|
||||||
|
close(u.TxPackets)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -3,14 +3,31 @@
|
||||||
|
|
||||||
package udp
|
package udp
|
||||||
|
|
||||||
// Windows support is primarily implemented in udp_generic, besides NewListenConfig
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
|
||||||
|
if multi {
|
||||||
|
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
|
||||||
|
// The udp stack would need to be reworked to hide away the implementation differences between
|
||||||
|
// Windows and Linux
|
||||||
|
return nil, fmt.Errorf("multiple udp listeners not supported on windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
rc, err := NewRIOListener(l, ip, port)
|
||||||
|
if err == nil {
|
||||||
|
return rc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
l.WithError(err).Error("Falling back to standard udp sockets")
|
||||||
|
return NewGenericListener(l, ip, port, multi, batch)
|
||||||
|
}
|
||||||
|
|
||||||
func NewListenConfig(multi bool) net.ListenConfig {
|
func NewListenConfig(multi bool) net.ListenConfig {
|
||||||
return net.ListenConfig{
|
return net.ListenConfig{
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
|
Loading…
Reference in New Issue