From e1af37e46d3e7712c23c4f26727139df82a25e87 Mon Sep 17 00:00:00 2001
From: Wade Simmons
Date: Mon, 13 Mar 2023 15:09:08 -0400
Subject: [PATCH] add calculated_remotes (#759)

* add calculated_remotes

This setting allows us to "guess" what the remote might be for a host
while we wait for the lighthouse response. For networks that were
designed with this in mind, it can help speed up handshake performance,
as well as improve resiliency in the case that all lighthouses are down.

Example:

lighthouse:
  # ...

  calculated_remotes:
    # For any Nebula IPs in 10.0.10.0/24, this will apply the mask and add
    # the calculated IP as an initial remote (while we wait for the response
    # from the lighthouse). Both CIDRs must have the same mask size.
    # For example, Nebula IP 10.0.10.123 will have a calculated remote of
    # 192.168.1.123
    10.0.10.0/24:
      - mask: 192.168.1.0/24
        port: 4242

* figure out what is up with this test

* add test

* better logic for sending handshakes

Keep track of the last list of hosts we sent handshakes to. Only log
"handshake sent" messages if that list has changed.

Remove the test Test_NewHandshakeManagerTrigger because it is faulty and
no longer makes sense. It relies on no handshake packets actually being
sent, but with these changes we now send packets (as we should!).

* use atomic.Pointer

* cleanup to make it clearer

* fix typo in example
---
 calculated_remote.go      | 143 ++++++++++++++++++++++++++++++++++++++
 calculated_remote_test.go |  27 +++++++
 examples/config.yml       |  13 ++++
 handshake_manager.go      |  36 ++++++----
 handshake_manager_test.go |  40 -----------
 hostmap.go                |  33 ++++-----
 inside.go                 |   8 ++-
 lighthouse.go             |  53 ++++++++++++++
 udp/udp_all.go            |  16 +++++
 9 files changed, 300 insertions(+), 69 deletions(-)
 create mode 100644 calculated_remote.go
 create mode 100644 calculated_remote_test.go

diff --git a/calculated_remote.go b/calculated_remote.go
new file mode 100644
index 0000000..910f757
--- /dev/null
+++ b/calculated_remote.go
@@ -0,0 +1,143 @@
+package nebula
+
+import (
+    "fmt"
+    "math"
+    "net"
+    "strconv"
+
+    "github.com/slackhq/nebula/cidr"
+    "github.com/slackhq/nebula/config"
+    "github.com/slackhq/nebula/iputil"
+)
+
+// This allows us to "guess" what the remote might be for a host while we wait
+// for the lighthouse response. See "lighthouse.calculated_remotes" in the
+// example config file.
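+//
+// A worked example, mirroring the commit message above: with a configured
+// mask of 192.168.1.0/24, the overlay IP 10.0.10.123 keeps its low 8 host
+// bits and takes its high 24 bits from 192.168.1.0, yielding a calculated
+// remote of 192.168.1.123 on the configured port.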
+
+type calculatedRemote struct {
+    ipNet  net.IPNet
+    maskIP iputil.VpnIp
+    mask   iputil.VpnIp
+    port   uint32
+}
+
+func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) {
+    // Ensure this is an IPv4 mask that we expect
+    ones, bits := ipNet.Mask.Size()
+    if ones == 0 || bits != 32 {
+        return nil, fmt.Errorf("invalid mask: %v", ipNet)
+    }
+    if port < 0 || port > math.MaxUint16 {
+        return nil, fmt.Errorf("invalid port: %d", port)
+    }
+
+    return &calculatedRemote{
+        ipNet:  *ipNet,
+        maskIP: iputil.Ip2VpnIp(ipNet.IP),
+        mask:   iputil.Ip2VpnIp(ipNet.Mask),
+        port:   uint32(port),
+    }, nil
+}
+
+func (c *calculatedRemote) String() string {
+    return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
+}
+
+func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
+    // Combine the masked bytes of the "mask" IP with the unmasked bytes
+    // of the overlay IP
+    masked := (c.maskIP & c.mask) | (ip & ^c.mask)
+
+    return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
+}
+
+func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4, error) {
+    value := c.Get(k)
+    if value == nil {
+        return nil, nil
+    }
+
+    calculatedRemotes := cidr.NewTree4()
+
+    rawMap, ok := value.(map[any]any)
+    if !ok {
+        return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
+    }
+    for rawKey, rawValue := range rawMap {
+        rawCIDR, ok := rawKey.(string)
+        if !ok {
+            return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
+        }
+
+        _, ipNet, err := net.ParseCIDR(rawCIDR)
+        if err != nil {
+            return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+        }
+
+        entry, err := newCalculatedRemotesListFromConfig(rawValue)
+        if err != nil {
+            return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
+        }
+
+        calculatedRemotes.AddCIDR(ipNet, entry)
+    }
+
+    return calculatedRemotes, nil
+}
+
+func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
+    rawList, ok := raw.([]any)
+    if !ok {
+        return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
+    }
+
+    var l []*calculatedRemote
+    for _, e := range rawList {
+        c, err := newCalculatedRemotesEntryFromConfig(e)
+        if err != nil {
+            return nil, fmt.Errorf("calculated_remotes entry: %w", err)
+        }
+        l = append(l, c)
+    }
+
+    return l, nil
+}
+
+func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
+    rawMap, ok := raw.(map[any]any)
+    if !ok {
+        return nil, fmt.Errorf("invalid type: %T", raw)
+    }
+
+    rawValue := rawMap["mask"]
+    if rawValue == nil {
+        return nil, fmt.Errorf("missing mask: %v", rawMap)
+    }
+    rawMask, ok := rawValue.(string)
+    if !ok {
+        return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue)
+    }
+    _, ipNet, err := net.ParseCIDR(rawMask)
+    if err != nil {
+        return nil, fmt.Errorf("invalid mask: %s", rawMask)
+    }
+
+    var port int
+    rawValue = rawMap["port"]
+    if rawValue == nil {
+        return nil, fmt.Errorf("missing port: %v", rawMap)
+    }
+    switch v := rawValue.(type) {
+    case int:
+        port = v
+    case string:
+        port, err = strconv.Atoi(v)
+        if err != nil {
+            return nil, fmt.Errorf("invalid port: %s: %w", v, err)
+        }
+    default:
+        return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
+    }
+
+    return newCalculatedRemote(ipNet, port)
+}
diff --git a/calculated_remote_test.go b/calculated_remote_test.go
new file mode 100644
index 0000000..2ddebca
--- /dev/null
+++ b/calculated_remote_test.go
@@ -0,0 +1,27 @@
+package nebula
+
+import (
+    "net"
+    "testing"
+
+    "github.com/slackhq/nebula/iputil"
+ "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCalculatedRemoteApply(t *testing.T) { + _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + require.NoError(t, err) + + c, err := newCalculatedRemote(ipNet, 4242) + require.NoError(t, err) + + input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182}) + + expected := &Ip4AndPort{ + Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})), + Port: 4242, + } + + assert.Equal(t, expected, c.Apply(input)) +} diff --git a/examples/config.yml b/examples/config.yml index 9fe95ce..f7bb95d 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -91,6 +91,19 @@ lighthouse: #- "1.1.1.1:4242" #- "1.2.3.4:0" # port will be replaced with the real listening port + # EXPERIMENTAL: This option may change or disappear in the future. + # This setting allows us to "guess" what the remote might be for a host + # while we wait for the lighthouse response. + #calculated_remotes: + # For any Nebula IPs in 10.0.10.0/24, this will apply the mask and add + # the calculated IP as an initial remote (while we wait for the response + # from the lighthouse). Both CIDRs must have the same mask size. + # For example, Nebula IP 10.0.10.123 will have a calculated remote of + # 192.168.1.123 + #10.0.10.0/24: + #- mask: 192.168.1.0/24 + # port: 4242 + # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # however using port 0 will dynamically assign a port and is recommended for roaming nodes. listen: diff --git a/handshake_manager.go b/handshake_manager.go index 06805b6..8166bda 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -142,14 +142,6 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l return } - // We only care about a lighthouse trigger before the first handshake transmit attempt. This is a very specific - // optimization for a fast lighthouse reply - //TODO: it would feel better to do this once, anytime, as our delay increases over time - if lighthouseTriggered && hostinfo.HandshakeCounter > 0 { - // If we didn't return here a lighthouse could cause us to aggressively send handshakes - return - } - // Get a remotes object if we don't already have one. // This is mainly to protect us as this should never be the case // NB ^ This comment doesn't jive. It's how the thing gets initialized. @@ -158,8 +150,22 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l hostinfo.remotes = c.lightHouse.QueryCache(vpnIp) } - //TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped) - if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 { + remotes := hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges) + remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes) + + // We only care about a lighthouse trigger if we have new remotes to send to. + // This is a very specific optimization for a fast lighthouse reply. + if lighthouseTriggered && !remotesHaveChanged { + // If we didn't return here a lighthouse could cause us to aggressively send handshakes + return + } + + hostinfo.HandshakeLastRemotes = remotes + + // TODO: this will generate a load of queries for hosts with only 1 ip + // (such as ones registered to the lighthouse with only a private IP) + // So we only do it one time after attempting 5 handshakes already. 
+    if len(remotes) <= 1 && hostinfo.HandshakeCounter == 5 {
         // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
         // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
         // the learned public ip for them. Query again to short circuit the promotion counter
@@ -182,12 +188,18 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
         }
     })

-    // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout
-    if len(sentTo) > 0 {
+    // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
+    // so only log when the list of remotes has changed
+    if remotesHaveChanged {
         hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
             WithField("initiatorIndex", hostinfo.localIndexId).
             WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
             Info("Handshake message sent")
+    } else if c.l.IsLevelEnabled(logrus.DebugLevel) {
+        hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
+            WithField("initiatorIndex", hostinfo.localIndexId).
+            WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+            Debug("Handshake message sent")
     }

     if c.config.useRelays && len(hostinfo.remotes.relays) > 0 {
diff --git a/handshake_manager_test.go b/handshake_manager_test.go
index 413a50a..84b8ef6 100644
--- a/handshake_manager_test.go
+++ b/handshake_manager_test.go
@@ -66,46 +66,6 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
     assert.NotContains(t, blah.pendingHostMap.Hosts, ip)
 }

-func Test_NewHandshakeManagerTrigger(t *testing.T) {
-    l := test.NewLogger()
-    _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
-    _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-    _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-    ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
-    preferredRanges := []*net.IPNet{localrange}
-    mw := &mockEncWriter{}
-    mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-    lh := newTestLighthouse()
-
-    blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
-
-    now := time.Now()
-    blah.NextOutboundHandshakeTimerTick(now, mw)
-
-    assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
-
-    hi := blah.AddVpnIp(ip, nil)
-    hi.HandshakeReady = true
-    assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
-    assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
-
-    // Trigger the same method the channel will but, this should set our remotes pointer
-    blah.handleOutbound(ip, mw, true)
-    assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have done a handshake attempt")
-    assert.NotNil(t, hi.remotes, "Manager should have set my remotes pointer")
-
-    // Make sure the trigger doesn't double schedule the timer entry
-    assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
-
-    uaddr := udp.NewAddrFromString("10.1.1.1:4242")
-    hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
-
-    // We now have remotes but only the first trigger should have pushed things forward
-    blah.handleOutbound(ip, mw, true)
-    assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have not done a handshake attempt")
-    assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
-}
-
 func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
     for _, i := range tw.t.wheel {
         n := i.Head
diff --git a/hostmap.go b/hostmap.go
index 231beb1..185ecf5 100644
--- a/hostmap.go
+++ b/hostmap.go
@@ -155,22 +155,23 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {

 type HostInfo struct {
     sync.RWMutex

-    remote            *udp.Addr
-    remotes           *RemoteList
-    promoteCounter    atomic.Uint32
-    ConnectionState   *ConnectionState
-    handshakeStart    time.Time        //todo: this an entry in the handshake manager
-    HandshakeReady    bool             //todo: being in the manager means you are ready
-    HandshakeCounter  int              //todo: another handshake manager entry
-    HandshakeComplete bool             //todo: this should go away in favor of ConnectionState.ready
-    HandshakePacket   map[uint8][]byte //todo: this is other handshake manager entry
-    packetStore       []*cachedPacket  //todo: this is other handshake manager entry
-    remoteIndexId     uint32
-    localIndexId      uint32
-    vpnIp             iputil.VpnIp
-    recvError         int
-    remoteCidr        *cidr.Tree4
-    relayState        RelayState
+    remote               *udp.Addr
+    remotes              *RemoteList
+    promoteCounter       atomic.Uint32
+    ConnectionState      *ConnectionState
+    handshakeStart       time.Time        //todo: this an entry in the handshake manager
+    HandshakeReady       bool             //todo: being in the manager means you are ready
+    HandshakeCounter     int              //todo: another handshake manager entry
+    HandshakeLastRemotes []*udp.Addr      //todo: another handshake manager entry, which remotes we sent to last time
+    HandshakeComplete    bool             //todo: this should go away in favor of ConnectionState.ready
+    HandshakePacket      map[uint8][]byte //todo: this is other handshake manager entry
+    packetStore          []*cachedPacket  //todo: this is other handshake manager entry
+    remoteIndexId        uint32
+    localIndexId         uint32
+    vpnIp                iputil.VpnIp
+    recvError            int
+    remoteCidr           *cidr.Tree4
+    relayState           RelayState

     // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
     // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
diff --git a/inside.go b/inside.go
index 0734883..ddfaa20 100644
--- a/inside.go
+++ b/inside.go
@@ -153,7 +153,13 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {

     // If this is a static host, we don't need to wait for the HostQueryReply
     // We can trigger the handshake right now
-    if _, ok := f.lightHouse.GetStaticHostList()[vpnIp]; ok {
+    _, doTrigger := f.lightHouse.GetStaticHostList()[vpnIp]
+    if !doTrigger {
+        // Add any calculated remotes, and trigger early handshake if one found
+        doTrigger = f.lightHouse.addCalculatedRemotes(vpnIp)
+    }
+
+    if doTrigger {
         select {
         case f.handshakeManager.trigger <- vpnIp:
         default:
diff --git a/lighthouse.go b/lighthouse.go
index 60e1f29..a3341b4 100644
--- a/lighthouse.go
+++ b/lighthouse.go
@@ -12,6 +12,7 @@ import (

     "github.com/rcrowley/go-metrics"
     "github.com/sirupsen/logrus"
+    "github.com/slackhq/nebula/cidr"
     "github.com/slackhq/nebula/config"
     "github.com/slackhq/nebula/header"
     "github.com/slackhq/nebula/iputil"
@@ -72,6 +73,8 @@ type LightHouse struct {
     // IP's of relays that can be used by peers to access me
     relaysForMe atomic.Pointer[[]iputil.VpnIp]

+    calculatedRemotes atomic.Pointer[cidr.Tree4] // Maps VpnIp to []*calculatedRemote
+
     metrics           *MessageMetrics
     metricHolepunchTx metrics.Counter
     l                 *logrus.Logger
@@ -161,6 +164,10 @@ func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
     return *lh.relaysForMe.Load()
 }

+func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4 {
+    return lh.calculatedRemotes.Load()
+}
+
 func (lh *LightHouse) GetUpdateInterval() int64 {
     return lh.interval.Load()
 }
@@ -237,6 +244,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
         }
     }

+    if initial || c.HasChanged("lighthouse.calculated_remotes") {
+        cr, err := NewCalculatedRemotesFromConfig(c, "lighthouse.calculated_remotes")
+        if err != nil {
+            return util.NewContextualError("Invalid lighthouse.calculated_remotes", nil, err)
+        }
+
+        lh.calculatedRemotes.Store(cr)
+        if !initial {
+            //TODO: a diff will be annoyingly difficult
+            lh.l.Info("lighthouse.calculated_remotes has changed")
+        }
+    }
+
     //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
     if initial || c.HasChanged("static_host_map") {
         staticList := make(map[iputil.VpnIp]struct{})
@@ -488,6 +508,39 @@ func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, stat
     staticList[vpnIp] = struct{}{}
 }

+// addCalculatedRemotes adds any calculated remotes based on the
+// lighthouse.calculated_remotes configuration. It returns true if any
+// calculated remotes were added.
+func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
+    tree := lh.getCalculatedRemotes()
+    if tree == nil {
+        return false
+    }
+    value := tree.MostSpecificContains(vpnIp)
+    if value == nil {
+        return false
+    }
+    calculatedRemotes := value.([]*calculatedRemote)
+
+    var calculated []*Ip4AndPort
+    for _, cr := range calculatedRemotes {
+        c := cr.Apply(vpnIp)
+        if c != nil {
+            calculated = append(calculated, c)
+        }
+    }
+
+    lh.Lock()
+    am := lh.unlockedGetRemoteList(vpnIp)
+    am.Lock()
+    defer am.Unlock()
+    lh.Unlock()
+
+    am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4)
+
+    return len(calculated) > 0
+}
+
 // unlockedGetRemoteList assumes you have the lh lock
 func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
     am, ok := lh.addrMap[vpnIp]
diff --git a/udp/udp_all.go b/udp/udp_all.go
index a4a462e..093bf69 100644
--- a/udp/udp_all.go
+++ b/udp/udp_all.go
@@ -64,6 +64,22 @@ func (ua *Addr) Copy() *Addr {
     return &nu
 }

+type AddrSlice []*Addr
+
+func (a AddrSlice) Equal(b AddrSlice) bool {
+    if len(a) != len(b) {
+        return false
+    }
+
+    for i := range a {
+        if !a[i].Equals(b[i]) {
+            return false
+        }
+    }
+
+    return true
+}
+
 func ParseIPAndPort(s string) (net.IP, uint16, error) {
     rIp, sPort, err := net.SplitHostPort(s)
     if err != nil {
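
For context on the logging change in handshake_manager.go, a minimal sketch
(illustrative only, not part of the patch) of how the new udp.AddrSlice
comparison behaves:

    package main

    import (
        "fmt"

        "github.com/slackhq/nebula/udp"
    )

    func main() {
        // handleOutbound wraps the current remotes in udp.AddrSlice and compares
        // them to HandshakeLastRemotes; Equal is true only when both lists have
        // the same length and every Addr matches pairwise, so order matters.
        last := udp.AddrSlice{udp.NewAddrFromString("192.168.1.123:4242")}
        current := udp.AddrSlice{udp.NewAddrFromString("192.168.1.123:4242")}
        fmt.Println(current.Equal(last)) // true: nothing new, so the send is logged at Debug

        // A newly learned remote changes the list, so the next attempt is
        // logged at Info ("Handshake message sent").
        current = append(current, udp.NewAddrFromString("10.0.10.123:4242"))
        fmt.Println(current.Equal(last)) // false
    }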