diff --git a/hostmap.go b/hostmap.go index 82e7016..6b726c1 100644 --- a/hostmap.go +++ b/hostmap.go @@ -15,7 +15,6 @@ import ( "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -36,7 +35,6 @@ type HostMap struct { Hosts map[iputil.VpnIp]*HostInfo preferredRanges []*net.IPNet vpnCIDR *net.IPNet - unsafeRoutes *cidr.Tree4 metricsEnabled bool l *logrus.Logger } @@ -99,7 +97,6 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang Hosts: h, preferredRanges: preferredRanges, vpnCIDR: vpnCIDR, - unsafeRoutes: cidr.NewTree4(), l: l, } return &m @@ -333,15 +330,6 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host return nil, errors.New("unable to find host") } -func (hm *HostMap) queryUnsafeRoute(ip iputil.VpnIp) iputil.VpnIp { - r := hm.unsafeRoutes.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } else { - return 0 - } -} - // We already have the hm Lock when this is called, so make sure to not call // any other methods that might try to grab it again func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) { @@ -409,13 +397,6 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) { } } -func (hm *HostMap) addUnsafeRoutes(routes *[]overlay.Route) { - for _, r := range *routes { - hm.l.WithField("cidr", r.Cidr).WithField("via", r.Via).Warn("Adding UNSAFE Route") - hm.unsafeRoutes.AddCIDR(r.Cidr, iputil.Ip2VpnIp(*r.Via)) - } -} - func (i *HostInfo) BindConnectionState(cs *ConnectionState) { i.ConnectionState = cs } diff --git a/inside.go b/inside.go index 7ab083a..988bb65 100644 --- a/inside.go +++ b/inside.go @@ -72,7 +72,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { //TODO: we can find contains without converting back to bytes if f.hostMap.vpnCIDR.Contains(vpnIp.ToIP()) == false { - vpnIp = f.hostMap.queryUnsafeRoute(vpnIp) + vpnIp = f.inside.RouteFor(vpnIp) if vpnIp == 0 { return nil } diff --git a/main.go b/main.go index 7f73365..c8384d4 100644 --- a/main.go +++ b/main.go @@ -78,14 +78,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg // TODO: make sure mask is 4 bytes tunCidr := cs.certificate.Details.Ips[0] - routes, err := overlay.ParseRoutes(c, tunCidr) - if err != nil { - return nil, util.NewContextualError("Could not parse tun.routes", nil, err) - } - unsafeRoutes, err := overlay.ParseUnsafeRoutes(c, tunCidr) - if err != nil { - return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) - } ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) wireSSHReload(l, ssh, c) @@ -142,7 +134,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if !configTest { c.CatchHUP(ctx) - tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, routes, unsafeRoutes, tunFd, routines) + tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines) if err != nil { return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err) } @@ -217,8 +209,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } hostMap := NewHostMap(l, "main", tunCidr, preferredRanges) - - hostMap.addUnsafeRoutes(&unsafeRoutes) hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false) l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created") diff --git a/overlay/device.go b/overlay/device.go index 24037af..f20ef4a 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -3,6 +3,8 @@ package overlay import ( "io" "net" + + "github.com/slackhq/nebula/iputil" ) type Device interface { @@ -11,5 +13,6 @@ type Device interface { CidrNet() *net.IPNet DeviceName() string WriteRaw([]byte) error + RouteFor(iputil.VpnIp) iputil.VpnIp NewMultiQueueReader() (io.ReadWriteCloser, error) } diff --git a/overlay/route.go b/overlay/route.go index a104dab..c6a3ad4 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -16,7 +16,7 @@ type Route struct { Via *net.IP } -func ParseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { var err error r := c.Get("tun.routes") @@ -86,7 +86,7 @@ func ParseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return routes, nil } -func ParseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { var err error r := c.Get("tun.unsafe_routes") diff --git a/overlay/tun_test.go b/overlay/route_test.go similarity index 87% rename from overlay/tun_test.go rename to overlay/route_test.go index 8adac5d..2128ddb 100644 --- a/overlay/tun_test.go +++ b/overlay/route_test.go @@ -10,73 +10,73 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_ParseRoutes(t *testing.T) { +func Test_parseRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) _, n, _ := net.ParseCIDR("10.0.0.0/24") // test no routes config - routes, err := ParseRoutes(c, n) + routes, err := parseRoutes(c, n) assert.Nil(t, err) assert.Len(t, routes, 0) // not an array c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} - routes, err = ParseRoutes(c, n) + routes, err = parseRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "tun.routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} - routes, err = ParseRoutes(c, n) + routes, err = parseRoutes(c, n) assert.Nil(t, err) assert.Len(t, routes, 0) // weird route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} - routes, err = ParseRoutes(c, n) + routes, err = parseRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = ParseRoutes(c, n) + routes, err = parseRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} - routes, err = ParseRoutes(c, n) + routes, err = parseRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} - routes, err = ParseRoutes(c, n) + routes, err = parseRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} - routes, err = ParseRoutes(c, n) + routes, err = parseRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} - routes, err = ParseRoutes(c, n) + routes, err = parseRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} - routes, err = ParseRoutes(c, n) + routes, err = parseRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24") // above network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} - routes, err = ParseRoutes(c, n) + routes, err = parseRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24") @@ -85,7 +85,7 @@ func Test_ParseRoutes(t *testing.T) { map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, }} - routes, err = ParseRoutes(c, n) + routes, err = parseRoutes(c, n) assert.Nil(t, err) assert.Len(t, routes, 2) @@ -106,37 +106,37 @@ func Test_ParseRoutes(t *testing.T) { } } -func Test_ParseUnsafeRoutes(t *testing.T) { +func Test_parseUnsafeRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) _, n, _ := net.ParseCIDR("10.0.0.0/24") // test no routes config - routes, err := ParseUnsafeRoutes(c, n) + routes, err := parseUnsafeRoutes(c, n) assert.Nil(t, err) assert.Len(t, routes, 0) // not an array c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, err) assert.Len(t, routes, 0) // weird route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") @@ -145,62 +145,62 @@ func Test_ParseUnsafeRoutes(t *testing.T) { 127, false, nil, 1.0, []string{"1", "2"}, } { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) } // unparsable via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope") // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Len(t, routes, 1) assert.Nil(t, err) // above network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Len(t, routes, 1) assert.Nil(t, err) // no mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Len(t, routes, 1) assert.Equal(t, DefaultMTU, routes[0].MTU) // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") @@ -210,7 +210,7 @@ func Test_ParseUnsafeRoutes(t *testing.T) { map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32"}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} - routes, err = ParseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, err) assert.Len(t, routes, 3) diff --git a/overlay/tun.go b/overlay/tun.go index cdb6b64..abe99f7 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -1,15 +1,30 @@ package overlay import ( + "fmt" "net" + "runtime" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" ) const DefaultMTU = 1300 -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routes, unsafeRoutes []Route, fd *int, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *int, routines int) (Device, error) { + routes, err := parseRoutes(c, tunCidr) + if err != nil { + return nil, util.NewContextualError("Could not parse tun.routes", nil, err) + } + + unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) + if err != nil { + return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) + } + routes = append(routes, unsafeRoutes...) + switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) @@ -22,7 +37,6 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout tunCidr, c.GetInt("tun.mtu", DefaultMTU), routes, - unsafeRoutes, c.GetInt("tun.tx_queue", 500), ) @@ -33,9 +47,22 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout tunCidr, c.GetInt("tun.mtu", DefaultMTU), routes, - unsafeRoutes, c.GetInt("tun.tx_queue", 500), routines > 1, ) } } + +func makeCidrTree(routes []Route, allowMTU bool) (*cidr.Tree4, error) { + cidrTree := cidr.NewTree4() + for _, r := range routes { + if !allowMTU && r.MTU > 0 { + return nil, fmt.Errorf("route MTU is not supported in %s", runtime.GOOS) + } + + if r.Via != nil { + cidrTree.AddCIDR(r.Cidr, r.Via) + } + } + return cidrTree, nil +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index ff37487..b541b9a 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -8,44 +8,43 @@ import ( "io" "net" "os" + "runtime" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/iputil" "golang.org/x/sys/unix" ) type tun struct { io.ReadWriteCloser - fd int - Device string - Cidr *net.IPNet - MaxMTU int - DefaultMTU int - TXQueueLen int - Routes []Route - UnsafeRoutes []Route - l *logrus.Logger + fd int + Cidr *net.IPNet + l *logrus.Logger } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, txQueueLen int) (*tun, error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) { + if len(routes) > 0 { + return nil, fmt.Errorf("routes are not supported in %s", runtime.GOOS) + } + file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") return &tun{ ReadWriteCloser: file, fd: int(file.Fd()), - Device: "android", Cidr: cidr, - DefaultMTU: defaultMTU, - TXQueueLen: txQueueLen, - Routes: routes, - UnsafeRoutes: unsafeRoutes, l: l, }, nil } -func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int, _ bool) (*tun, error) { +func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } +func (t *tun) RouteFor(iputil.VpnIp) iputil.VpnIp { + return 0 +} + func (t *tun) WriteRaw(b []byte) error { var nn int for { @@ -77,7 +76,7 @@ func (t *tun) CidrNet() *net.IPNet { } func (t *tun) DeviceName() string { - return t.Device + return "android" } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 51b8048..987e03d 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -12,18 +12,20 @@ import ( "unsafe" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/iputil" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" ) type tun struct { io.ReadWriteCloser - Device string - Cidr *net.IPNet - DefaultMTU int - TXQueueLen int - UnsafeRoutes []Route - l *logrus.Logger + Device string + Cidr *net.IPNet + DefaultMTU int + Routes []Route + cidrTree *cidr.Tree4 + l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte @@ -74,9 +76,10 @@ type ifreqMTU struct { pad [8]byte } -func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, txQueueLen int, _ bool) (*tun, error) { - if len(routes) > 0 { - return nil, fmt.Errorf("route MTU not supported in Darwin") +func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { + cidrTree, err := makeCidrTree(routes, false) + if err != nil { + return nil, err } ifIndex := -1 @@ -151,8 +154,8 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout Device: name, Cidr: cidr, DefaultMTU: defaultMTU, - TXQueueLen: txQueueLen, - UnsafeRoutes: unsafeRoutes, + Routes: routes, + cidrTree: cidrTree, l: l, } @@ -166,7 +169,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int) (*tun, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -279,7 +282,12 @@ func (t *tun) Activate() error { } // Unsafe path routes - for _, r := range t.UnsafeRoutes { + for _, r := range t.Routes { + if r.Via == nil { + // We don't allow route MTUs so only install routes with a via + continue + } + copy(routeAddr.IP[:], r.Cidr.IP.To4()) copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) @@ -294,6 +302,15 @@ func (t *tun) Activate() error { return nil } +func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + r := t.cidrTree.MostSpecificContains(ip) + if r != nil { + return r.(iputil.VpnIp) + } + + return 0 +} + // Get the LinkAddr for the interface of the given name // TODO: Is there an easier way to fetch this when we create the interface? // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers. diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index 3718223..17e11e6 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -9,6 +9,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/iputil" ) type disabledTun struct { @@ -43,6 +44,10 @@ func (*disabledTun) Activate() error { return nil } +func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp { + return 0 +} + func (t *disabledTun) CidrNet() *net.IPNet { return t.cidr } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 7e717a0..affa870 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -14,16 +14,19 @@ import ( "strings" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/iputil" ) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) type tun struct { - Device string - Cidr *net.IPNet - MTU int - UnsafeRoutes []Route - l *logrus.Logger + Device string + Cidr *net.IPNet + MTU int + Routes []Route + cidrTree *cidr.Tree4 + l *logrus.Logger io.ReadWriteCloser } @@ -35,14 +38,16 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int) (*tun, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, _ int, _ bool) (*tun, error) { - if len(routes) > 0 { - return nil, fmt.Errorf("route MTU not supported in FreeBSD") +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { + cidrTree, err := makeCidrTree(routes, false) + if err != nil { + return nil, err } + if strings.HasPrefix(deviceName, "/dev/") { deviceName = strings.TrimPrefix(deviceName, "/dev/") } @@ -50,11 +55,12 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`") } return &tun{ - Device: deviceName, - Cidr: cidr, - MTU: defaultMTU, - UnsafeRoutes: unsafeRoutes, - l: l, + Device: deviceName, + Cidr: cidr, + MTU: defaultMTU, + Routes: routes, + cidrTree: cidrTree, + l: l, }, nil } @@ -79,7 +85,12 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run 'ifconfig': %s", err) } // Unsafe path routes - for _, r := range t.UnsafeRoutes { + for _, r := range t.Routes { + if r.Via == nil { + // We don't allow route MTUs so only install routes with a via + continue + } + t.l.Debug("command: route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) if err = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device).Run(); err != nil { return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err) @@ -89,6 +100,15 @@ func (t *tun) Activate() error { return nil } +func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + r := t.cidrTree.MostSpecificContains(ip) + if r != nil { + return r.(iputil.VpnIp) + } + + return 0 +} + func (t *tun) CidrNet() *net.IPNet { return t.Cidr } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 8bc8d3c..61f69c9 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -9,25 +9,26 @@ import ( "io" "net" "os" + "runtime" "sync" "syscall" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/iputil" ) type tun struct { io.ReadWriteCloser - Device string - Cidr *net.IPNet + Cidr *net.IPNet } -func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int, _ bool) (*tun, error) { +func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(_ *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ []Route, _ int) (*tun, error) { +func newTunFromFd(_ *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) { if len(routes) > 0 { - return nil, fmt.Errorf("route MTU not supported in Darwin") + return nil, fmt.Errorf("routes are not supported in %s", runtime.GOOS) } file := os.NewFile(uintptr(deviceFd), "/dev/tun") @@ -42,6 +43,10 @@ func (t *tun) Activate() error { return nil } +func (t *tun) RouteFor(iputil.VpnIp) iputil.VpnIp { + return 0 +} + func (t *tun) WriteRaw(b []byte) error { _, err := t.Write(b) return err @@ -73,7 +78,6 @@ func (tr *tunReadCloser) Read(to []byte) (int, error) { } func (tr *tunReadCloser) Write(from []byte) (int, error) { - if len(from) == 0 { return 0, syscall.EIO } @@ -111,7 +115,7 @@ func (t *tun) CidrNet() *net.IPNet { } func (t *tun) DeviceName() string { - return t.Device + return "iOS" } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index f5c29b3..87aac32 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -12,21 +12,23 @@ import ( "unsafe" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/iputil" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) type tun struct { io.ReadWriteCloser - fd int - Device string - Cidr *net.IPNet - MaxMTU int - DefaultMTU int - TXQueueLen int - Routes []Route - UnsafeRoutes []Route - l *logrus.Logger + fd int + Device string + Cidr *net.IPNet + MaxMTU int + DefaultMTU int + TXQueueLen int + Routes []Route + cidrTree *cidr.Tree4 + l *logrus.Logger } type ifReq struct { @@ -61,7 +63,11 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, txQueueLen int) (*tun, error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) { + cidrTree, err := makeCidrTree(routes, true) + if err != nil { + return nil, err + } file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -73,12 +79,12 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in DefaultMTU: defaultMTU, TXQueueLen: txQueueLen, Routes: routes, - UnsafeRoutes: unsafeRoutes, + cidrTree: cidrTree, l: l, }, nil } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, txQueueLen int, multiqueue bool) (*tun, error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -104,6 +110,11 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int } } + cidrTree, err := makeCidrTree(routes, true) + if err != nil { + return nil, err + } + return &tun{ ReadWriteCloser: file, fd: int(file.Fd()), @@ -113,7 +124,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int DefaultMTU: defaultMTU, TXQueueLen: txQueueLen, Routes: routes, - UnsafeRoutes: unsafeRoutes, + cidrTree: cidrTree, l: l, }, nil } @@ -136,6 +147,15 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return file, nil } +func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + r := t.cidrTree.MostSpecificContains(ip) + if r != nil { + return r.(iputil.VpnIp) + } + + return 0 +} + func (t *tun) WriteRaw(b []byte) error { var nn int for { @@ -266,21 +286,8 @@ func (t tun) Activate() error { Scope: unix.RT_SCOPE_LINK, } - err = netlink.RouteAdd(&nr) - if err != nil { - return fmt.Errorf("failed to set mtu %v on route %v; %v", r.MTU, r.Cidr, err) - } - } - - // Unsafe path routes - for _, r := range t.UnsafeRoutes { - nr := netlink.Route{ - LinkIndex: link.Attrs().Index, - Dst: r.Cidr, - MTU: r.MTU, - Priority: r.Metric, - AdvMSS: t.advMSS(r), - Scope: unix.RT_SCOPE_LINK, + if r.Metric > 0 { + nr.Priority = r.Metric } err = netlink.RouteAdd(&nr) diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 014d256..2638b97 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -9,32 +9,39 @@ import ( "net" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/iputil" ) type TestTun struct { - Device string - Cidr *net.IPNet - MTU int - UnsafeRoutes []Route - l *logrus.Logger + Device string + Cidr *net.IPNet + Routes []Route + cidrTree *cidr.Tree4 + l *logrus.Logger rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, _ []Route, unsafeRoutes []Route, _ int, _ bool) (*TestTun, error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) { + cidrTree, err := makeCidrTree(routes, false) + if err != nil { + return nil, err + } + return &TestTun{ - Device: deviceName, - Cidr: cidr, - MTU: defaultMTU, - UnsafeRoutes: unsafeRoutes, - l: l, - rxPackets: make(chan []byte, 1), - TxPackets: make(chan []byte, 1), + Device: deviceName, + Cidr: cidr, + Routes: routes, + cidrTree: cidrTree, + l: l, + rxPackets: make(chan []byte, 1), + TxPackets: make(chan []byte, 1), }, nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int) (*TestTun, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -66,6 +73,15 @@ func (t *TestTun) Get(block bool) []byte { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// +func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + r := t.cidrTree.MostSpecificContains(ip) + if r != nil { + return r.(iputil.VpnIp) + } + + return 0 +} + func (t *TestTun) Activate() error { return nil } diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go index d3835a0..b219fef 100644 --- a/overlay/tun_water_windows.go +++ b/overlay/tun_water_windows.go @@ -7,24 +7,33 @@ import ( "os/exec" "strconv" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/iputil" "github.com/songgao/water" ) type waterTun struct { - Device string - Cidr *net.IPNet - MTU int - UnsafeRoutes []Route + Device string + Cidr *net.IPNet + MTU int + Routes []Route + cidrTree *cidr.Tree4 *water.Interface } -func newWaterTun(cidr *net.IPNet, defaultMTU int, unsafeRoutes []Route) (*waterTun, error) { +func newWaterTun(cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) { + cidrTree, err := makeCidrTree(routes, false) + if err != nil { + return nil, err + } + // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() return &waterTun{ - Cidr: cidr, - MTU: defaultMTU, - UnsafeRoutes: unsafeRoutes, + Cidr: cidr, + MTU: defaultMTU, + Routes: routes, + cidrTree: cidrTree, }, nil } @@ -69,7 +78,12 @@ func (t *waterTun) Activate() error { return fmt.Errorf("failed to find interface named %s: %v", t.Device, err) } - for _, r := range t.UnsafeRoutes { + for _, r := range t.Routes { + if r.Via == nil { + // We don't allow route MTUs so only install routes with a via + continue + } + err = exec.Command( "C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(iface.Index), "METRIC", strconv.Itoa(r.Metric), ).Run() @@ -81,6 +95,15 @@ func (t *waterTun) Activate() error { return nil } +func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + r := t.cidrTree.MostSpecificContains(ip) + if r != nil { + return r.(iputil.VpnIp) + } + + return 0 +} + func (t *waterTun) CidrNet() *net.IPNet { return t.Cidr } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 57d28c2..bbd748d 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -14,11 +14,11 @@ import ( "github.com/sirupsen/logrus" ) -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int) (Device, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, _ int, _ bool) (Device, error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (Device, error) { if len(routes) > 0 { return nil, fmt.Errorf("route MTU not supported in Windows") } @@ -30,14 +30,14 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int } if useWintun { - device, err := newWinTun(deviceName, cidr, defaultMTU, unsafeRoutes) + device, err := newWinTun(deviceName, cidr, defaultMTU, routes) if err != nil { return nil, fmt.Errorf("create Wintun interface failed, %w", err) } return device, nil } - device, err := newWaterTun(cidr, defaultMTU, unsafeRoutes) + device, err := newWaterTun(cidr, defaultMTU, routes) if err != nil { return nil, fmt.Errorf("create wintap driver failed, %w", err) } diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index 970e739..37f30e5 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -7,6 +7,8 @@ import ( "net" "unsafe" + "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/wintun" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" @@ -15,10 +17,11 @@ import ( const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { - Device string - Cidr *net.IPNet - MTU int - UnsafeRoutes []Route + Device string + Cidr *net.IPNet + MTU int + Routes []Route + cidrTree *cidr.Tree4 tun *wintun.NativeTun } @@ -42,7 +45,7 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) { return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } -func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, unsafeRoutes []Route) (*winTun, error) { +func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) { guid, err := generateGUIDByDeviceName(deviceName) if err != nil { return nil, fmt.Errorf("generate GUID failed: %w", err) @@ -53,11 +56,17 @@ func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, unsafeRoutes return nil, fmt.Errorf("create TUN device failed: %w", err) } + cidrTree, err := makeCidrTree(routes, false) + if err != nil { + return nil, err + } + return &winTun{ - Device: deviceName, - Cidr: cidr, - MTU: defaultMTU, - UnsafeRoutes: unsafeRoutes, + Device: deviceName, + Cidr: cidr, + MTU: defaultMTU, + Routes: routes, + cidrTree: cidrTree, tun: tunDevice.(*wintun.NativeTun), }, nil @@ -71,11 +80,16 @@ func (t *winTun) Activate() error { } foundDefault4 := false - routes := make([]*winipcfg.RouteData, 0, len(t.UnsafeRoutes)+1) + routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1) + + for _, r := range t.Routes { + if r.Via == nil { + // We don't allow route MTUs so only install routes with a via + continue + } - for _, r := range t.UnsafeRoutes { if !foundDefault4 { - if cidr, bits := r.Cidr.Mask.Size(); cidr == 0 && bits != 0 { + if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 { foundDefault4 = true } } @@ -110,6 +124,15 @@ func (t *winTun) Activate() error { return nil } +func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { + r := t.cidrTree.MostSpecificContains(ip) + if r != nil { + return r.(iputil.VpnIp) + } + + return 0 +} + func (t *winTun) CidrNet() *net.IPNet { return t.Cidr }