diff --git a/cidr/tree4.go b/cidr/tree4.go index 28d0e78..0839c90 100644 --- a/cidr/tree4.go +++ b/cidr/tree4.go @@ -13,8 +13,14 @@ type Node struct { value interface{} } +type entry struct { + CIDR *net.IPNet + Value *interface{} +} + type Tree4 struct { root *Node + list []entry } const ( @@ -24,6 +30,7 @@ const ( func NewTree4() *Tree4 { tree := new(Tree4) tree.root = &Node{} + tree.list = []entry{} return tree } @@ -53,6 +60,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { // We already have this range so update the value if next != nil { + addCIDR := cidr.String() + for i, v := range tree.list { + if addCIDR == v.CIDR.String() { + tree.list = append(tree.list[:i], tree.list[i+1:]...) + break + } + } + + tree.list = append(tree.list, entry{CIDR: cidr, Value: &val}) node.value = val return } @@ -74,9 +90,10 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { // Final node marks our cidr, set the value node.value = val + tree.list = append(tree.list, entry{CIDR: cidr, Value: &val}) } -// Finds the first match, which may be the least specific +// Contains finds the first match, which may be the least specific func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { bit := startbit node := tree.root @@ -99,7 +116,7 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { return value } -// Finds the most specific match +// MostSpecificContains finds the most specific match func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { bit := startbit node := tree.root @@ -121,7 +138,7 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { return value } -// Finds the most specific match +// Match finds the most specific match func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { bit := startbit node := tree.root @@ -143,3 +160,8 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { } return value } + +// List will return all CIDRs and their current values. Do not modify the contents! +func (tree *Tree4) List() []entry { + return tree.list +} diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go index 07f2b0a..dce8d54 100644 --- a/cidr/tree4_test.go +++ b/cidr/tree4_test.go @@ -8,6 +8,20 @@ import ( "github.com/stretchr/testify/assert" ) +func TestCIDRTree_List(t *testing.T) { + tree := NewTree4() + tree.AddCIDR(Parse("1.0.0.0/16"), "1") + tree.AddCIDR(Parse("1.0.0.0/8"), "2") + tree.AddCIDR(Parse("1.0.0.0/16"), "3") + tree.AddCIDR(Parse("1.0.0.0/16"), "4") + list := tree.List() + assert.Len(t, list, 2) + assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String()) + assert.Equal(t, "2", *list[0].Value) + assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String()) + assert.Equal(t, "4", *list[1].Value) +} + func TestCIDRTree_Contains(t *testing.T) { tree := NewTree4() tree.AddCIDR(Parse("1.0.0.0/8"), "1") diff --git a/examples/config.yml b/examples/config.yml index db5d0e3..9356b3a 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -223,6 +223,10 @@ tun: # metric: 100 # install: true + # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of + # in nebula configuration files. Default false, not reloadable. + #use_system_route_table: false + # TODO # Configure logging level logging: diff --git a/overlay/tun.go b/overlay/tun.go index 3da50b8..5eccec9 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -35,6 +35,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd * c.GetInt("tun.mtu", DefaultMTU), routes, c.GetInt("tun.tx_queue", 500), + c.GetBool("tun.use_system_route_table", false), ) default: @@ -46,6 +47,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd * routes, c.GetInt("tun.tx_queue", 500), routines > 1, + c.GetBool("tun.use_system_route_table", false), ) } } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 321aec8..c731d78 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -22,7 +22,7 @@ type tun struct { l *logrus.Logger } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) { routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err @@ -41,7 +41,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes }, nil } -func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { +func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 6320570..fd3429d 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -77,7 +77,7 @@ type ifreqMTU struct { pad [8]byte } -func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { +func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err @@ -170,7 +170,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 1054228..99cbdb0 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -38,11 +38,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 59c190e..26f34ec 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -23,11 +23,11 @@ type tun struct { routeTree *cidr.Tree4 } -func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { +func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) { routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 932b585..7833186 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,11 +4,13 @@ package overlay import ( + "bytes" "fmt" "io" "net" "os" "strings" + "sync/atomic" "unsafe" "github.com/sirupsen/logrus" @@ -26,9 +28,13 @@ type tun struct { MaxMTU int DefaultMTU int TXQueueLen int - Routes []Route - routeTree *cidr.Tree4 - l *logrus.Logger + + Routes []Route + routeTree atomic.Pointer[cidr.Tree4] + routeChan chan struct{} + useSystemRoutes bool + + l *logrus.Logger } type ifReq struct { @@ -63,7 +69,7 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) { routeTree, err := makeRouteTree(l, routes, true) if err != nil { return nil, err @@ -71,7 +77,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - return &tun{ + t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), Device: "tun0", @@ -79,12 +85,14 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in DefaultMTU: defaultMTU, TXQueueLen: txQueueLen, Routes: routes, - routeTree: routeTree, + useSystemRoutes: useSystemRoutes, l: l, - }, nil + } + t.routeTree.Store(routeTree) + return t, nil } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool) (*tun, error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -119,7 +127,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int return nil, err } - return &tun{ + t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), Device: name, @@ -128,9 +136,11 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int DefaultMTU: defaultMTU, TXQueueLen: txQueueLen, Routes: routes, - routeTree: routeTree, + useSystemRoutes: useSystemRoutes, l: l, - }, nil + } + t.routeTree.Store(routeTree) + return t, nil } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { @@ -152,7 +162,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) + r := t.routeTree.Load().MostSpecificContains(ip) if r != nil { return r.(iputil.VpnIp) } @@ -183,16 +193,20 @@ func (t *tun) Write(b []byte) (int, error) { } } -func (t tun) deviceBytes() (o [16]byte) { +func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) } return } -func (t tun) Activate() error { +func (t *tun) Activate() error { devName := t.deviceBytes() + if t.useSystemRoutes { + t.watchRoutes() + } + var addr, mask [4]byte copy(addr[:], t.cidr.IP.To4()) @@ -318,7 +332,7 @@ func (t *tun) Name() string { return t.Device } -func (t tun) advMSS(r Route) int { +func (t *tun) advMSS(r Route) int { mtu := r.MTU if r.MTU == 0 { mtu = t.DefaultMTU @@ -330,3 +344,83 @@ func (t tun) advMSS(r Route) int { } return 0 } + +func (t *tun) watchRoutes() { + rch := make(chan netlink.RouteUpdate) + doneChan := make(chan struct{}) + + if err := netlink.RouteSubscribe(rch, doneChan); err != nil { + t.l.WithError(err).Errorf("failed to subscribe to system route changes") + return + } + + t.routeChan = doneChan + + go func() { + for { + select { + case r := <-rch: + t.updateRoutes(r) + case <-doneChan: + // netlink.RouteSubscriber will close the rch for us + return + } + } + }() +} + +func (t *tun) updateRoutes(r netlink.RouteUpdate) { + if r.Gw == nil { + // Not a gateway route, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route") + return + } + + if !t.cidr.Contains(r.Gw) { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + return + } + + if x := r.Dst.IP.To4(); x == nil { + // Nebula only handles ipv4 on the overlay currently + t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4") + return + } + + newTree := cidr.NewTree4() + if r.Type == unix.RTM_NEWROUTE { + for _, oldR := range t.routeTree.Load().List() { + newTree.AddCIDR(oldR.CIDR, oldR.Value) + } + + t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") + newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw)) + + } else { + gw := iputil.Ip2VpnIp(r.Gw) + for _, oldR := range t.routeTree.Load().List() { + if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw { + // This is the record to delete + t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") + continue + } + + newTree.AddCIDR(oldR.CIDR, oldR.Value) + } + } + + t.routeTree.Store(newTree) +} + +func (t *tun) Close() error { + if t.routeChan != nil { + close(t.routeChan) + } + + if t.ReadWriteCloser != nil { + t.ReadWriteCloser.Close() + } + + return nil +} diff --git a/overlay/tun_linux_test.go b/overlay/tun_linux_test.go index 6c2043d..1c1842d 100644 --- a/overlay/tun_linux_test.go +++ b/overlay/tun_linux_test.go @@ -7,19 +7,19 @@ import "testing" var runAdvMSSTests = []struct { name string - tun tun + tun *tun r Route expected int }{ // Standard case, default MTU is the device max MTU - {"default", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0}, - {"default-min", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0}, - {"default-low", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160}, + {"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0}, + {"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0}, + {"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160}, // Case where we have a route MTU set higher than the default - {"route", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400}, - {"route-min", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400}, - {"route-high", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0}, + {"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400}, + {"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400}, + {"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0}, } func TestTunAdvMSS(t *testing.T) { diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 38c11a6..3a49dcb 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -25,7 +25,7 @@ type TestTun struct { TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) { routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err @@ -42,7 +42,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes }, nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*TestTun, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index e35e98b..57d90cb 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -14,11 +14,11 @@ import ( "github.com/sirupsen/logrus" ) -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (Device, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (Device, error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) { useWintun := true if err := checkWinTunExists(); err != nil { l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")