diff --git a/connection_manager.go b/connection_manager.go index f5dd594..0b277b5 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -457,7 +457,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } if n.punchy.GetTargetEverything() { - hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) { + hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) { n.metricsTxPunchy.Inc(1) n.intf.outside.WriteTo([]byte{1}, addr) }) diff --git a/connection_manager_test.go b/connection_manager_test.go index a2607a2..f50bcf8 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -43,7 +43,9 @@ func Test_NewConnectionManagerTest(t *testing.T) { preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects - hostMap := NewHostMap(l, vpncidr, preferredRanges) + hostMap := newHostMap(l, vpncidr) + hostMap.preferredRanges.Store(&preferredRanges) + cs := &CertState{ RawCertificate: []byte{}, PrivateKey: []byte{}, @@ -123,7 +125,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) { preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects - hostMap := NewHostMap(l, vpncidr, preferredRanges) + hostMap := newHostMap(l, vpncidr) + hostMap.preferredRanges.Store(&preferredRanges) + cs := &CertState{ RawCertificate: []byte{}, PrivateKey: []byte{}, @@ -210,7 +214,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") preferredRanges := []*net.IPNet{localrange} - hostMap := NewHostMap(l, vpncidr, preferredRanges) + hostMap := newHostMap(l, vpncidr) + hostMap.preferredRanges.Store(&preferredRanges) // Generate keys for CA and peer's cert. pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader) diff --git a/control.go b/control.go index 1e27b0f..c227b20 100644 --- a/control.go +++ b/control.go @@ -145,7 +145,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH return nil } - ch := copyHostInfo(h, c.f.hostMap.preferredRanges) + ch := copyHostInfo(h, c.f.hostMap.GetPreferredRanges()) return &ch } @@ -157,7 +157,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control } hostInfo.SetRemote(addr.Copy()) - ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges) + ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges()) return &ch } diff --git a/control_test.go b/control_test.go index 847332b..c64a3a4 100644 --- a/control_test.go +++ b/control_test.go @@ -18,7 +18,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0)) + hm := newHostMap(l, &net.IPNet{}) + hm.preferredRanges.Store(&[]*net.IPNet{}) + remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444) remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) ipNet := net.IPNet{ diff --git a/handshake_ix.go b/handshake_ix.go index 1905c00..9107d97 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -406,7 +406,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). - WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). + WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). Info("Blocked addresses for handshakes") // Swap the packet store to benefit the original intended recipient diff --git a/handshake_manager.go b/handshake_manager.go index b568cc8..b14b0fd 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -181,7 +181,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger hostinfo := hh.hostinfo // If we are out of time, clean up if hh.counter >= hm.config.retries { - hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)). + hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())). WithField("initiatorIndex", hh.hostinfo.localIndexId). WithField("remoteIndex", hh.hostinfo.remoteIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). @@ -211,7 +211,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) } - remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges) + remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes) // We only care about a lighthouse trigger if we have new remotes to send to. @@ -235,7 +235,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply var sentTo []*udp.Addr - hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { + hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) { hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { @@ -362,7 +362,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han hm.mainHostMap.RUnlock() // Do not attempt promotion if you are a lighthouse if !hm.lightHouse.amLighthouse { - h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f) + h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f) } return h, true } @@ -599,7 +599,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { } func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { - return c.mainHostMap.preferredRanges + return c.mainHostMap.GetPreferredRanges() } func (c *HandshakeManager) ForEachVpnIp(f controlEach) { diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 303aa50..9a63357 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -19,7 +19,9 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { _, localrange, _ := net.ParseCIDR("10.1.1.1/24") ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} - mainHM := NewHostMap(l, vpncidr, preferredRanges) + mainHM := newHostMap(l, vpncidr) + mainHM.preferredRanges.Store(&preferredRanges) + lh := newTestLighthouse() cs := &CertState{ diff --git a/hostmap.go b/hostmap.go index a5adeb9..589a124 100644 --- a/hostmap.go +++ b/hostmap.go @@ -11,6 +11,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" @@ -57,9 +58,8 @@ type HostMap struct { Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object RemoteIndexes map[uint32]*HostInfo Hosts map[iputil.VpnIp]*HostInfo - preferredRanges []*net.IPNet + preferredRanges atomic.Pointer[[]*net.IPNet] vpnCIDR *net.IPNet - metricsEnabled bool l *logrus.Logger } @@ -254,21 +254,53 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { - h := map[iputil.VpnIp]*HostInfo{} - i := map[uint32]*HostInfo{} - r := map[uint32]*HostInfo{} - relays := map[uint32]*HostInfo{} - m := HostMap{ - Indexes: i, - Relays: relays, - RemoteIndexes: r, - Hosts: h, - preferredRanges: preferredRanges, - vpnCIDR: vpnCIDR, - l: l, +func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap { + hm := newHostMap(l, vpnCIDR) + + hm.reload(c, true) + c.RegisterReloadCallback(func(c *config.C) { + hm.reload(c, false) + }) + + l.WithField("network", hm.vpnCIDR.String()). + WithField("preferredRanges", hm.GetPreferredRanges()). + Info("Main HostMap created") + + return hm +} + +func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap { + return &HostMap{ + Indexes: map[uint32]*HostInfo{}, + Relays: map[uint32]*HostInfo{}, + RemoteIndexes: map[uint32]*HostInfo{}, + Hosts: map[iputil.VpnIp]*HostInfo{}, + vpnCIDR: vpnCIDR, + l: l, + } +} + +func (hm *HostMap) reload(c *config.C, initial bool) { + if initial || c.HasChanged("preferred_ranges") { + var preferredRanges []*net.IPNet + rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) + + for _, rawPreferredRange := range rawPreferredRanges { + _, preferredRange, err := net.ParseCIDR(rawPreferredRange) + + if err != nil { + hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") + continue + } + + preferredRanges = append(preferredRanges, preferredRange) + } + + oldRanges := hm.preferredRanges.Swap(&preferredRanges) + if !initial { + hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed") + } } - return &m } // EmitStats reports host, index, and relay counts to the stats collection system @@ -457,7 +489,7 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostI hm.RUnlock() // Do not attempt promotion if you are a lighthouse if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse { - h.TryPromoteBest(hm.preferredRanges, promoteIfce) + h.TryPromoteBest(hm.GetPreferredRanges(), promoteIfce) } return h @@ -504,7 +536,8 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { } func (hm *HostMap) GetPreferredRanges() []*net.IPNet { - return hm.preferredRanges + //NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer + return *hm.preferredRanges.Load() } func (hm *HostMap) ForEachVpnIp(f controlEach) { @@ -596,7 +629,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { // NOTE: We do this loop here instead of calling `isPreferred` in // remote_list.go so that we only have to loop over preferredRanges once. newIsPreferred := false - for _, l := range hm.preferredRanges { + for _, l := range hm.GetPreferredRanges() { // return early if we are already on a preferred remote if l.Contains(currentRemote.IP) { return false diff --git a/hostmap_test.go b/hostmap_test.go index c1c0dce..8311cef 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -4,19 +4,19 @@ import ( "net" "testing" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() - hm := NewHostMap( + hm := newHostMap( l, &net.IPNet{ IP: net.IP{10, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}, }, - []*net.IPNet{}, ) f := &Interface{} @@ -91,13 +91,12 @@ func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() - hm := NewHostMap( + hm := newHostMap( l, &net.IPNet{ IP: net.IP{10, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}, }, - []*net.IPNet{}, ) f := &Interface{} @@ -205,3 +204,33 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { prim = hm.QueryVpnIp(1) assert.Nil(t, prim) } + +func TestHostMap_reload(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + + hm := NewHostMapFromConfig( + l, + &net.IPNet{ + IP: net.IP{10, 0, 0, 1}, + Mask: net.IPMask{255, 255, 255, 0}, + }, + c, + ) + + toS := func(ipn []*net.IPNet) []string { + var s []string + for _, n := range ipn { + s = append(s, n.String()) + } + return s + } + + assert.Empty(t, hm.GetPreferredRanges()) + + c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]") + assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges())) + + c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") + assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) +} diff --git a/main.go b/main.go index 8c94e80..7a0a0cf 100644 --- a/main.go +++ b/main.go @@ -183,52 +183,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } } - // Set up my internal host map - var preferredRanges []*net.IPNet - rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) - // First, check if 'preferred_ranges' is set and fallback to 'local_range' - if len(rawPreferredRanges) > 0 { - for _, rawPreferredRange := range rawPreferredRanges { - _, preferredRange, err := net.ParseCIDR(rawPreferredRange) - if err != nil { - return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err) - } - preferredRanges = append(preferredRanges, preferredRange) - } - } - - // local_range was superseded by preferred_ranges. If it is still present, - // merge the local_range setting into preferred_ranges. We will probably - // deprecate local_range and remove in the future. - rawLocalRange := c.GetString("local_range", "") - if rawLocalRange != "" { - _, localRange, err := net.ParseCIDR(rawLocalRange) - if err != nil { - return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err) - } - - // Check if the entry for local_range was already specified in - // preferred_ranges. Don't put it into the slice twice if so. - var found bool - for _, r := range preferredRanges { - if r.String() == localRange.String() { - found = true - break - } - } - if !found { - preferredRanges = append(preferredRanges, localRange) - } - } - - hostMap := NewHostMap(l, tunCidr, preferredRanges) - hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false) - - l. - WithField("network", hostMap.vpnCIDR.String()). - WithField("preferredRanges", hostMap.preferredRanges). - Info("Main HostMap created") - + hostMap := NewHostMapFromConfig(l, tunCidr, c) punchy := NewPunchyFromConfig(l, c) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) if err != nil { diff --git a/ssh.go b/ssh.go index e99205c..095f0fd 100644 --- a/ssh.go +++ b/ssh.go @@ -939,7 +939,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr enc.SetIndent("", " ") } - return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges)) + return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges())) } func sshReload(c *config.C, w sshd.StringWriter) error {