diff --git a/control_test.go b/control_test.go index ec469b4..de46991 100644 --- a/control_test.go +++ b/control_test.go @@ -47,7 +47,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { Signature: []byte{1, 2, 1, 2, 1, 3}, } - remotes := NewRemoteList() + remotes := NewRemoteList(nil) remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{ diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 5635c40..3e39e48 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.False(t, initCalled) assert.Same(t, i, i2) - i.remotes = NewRemoteList() + i.remotes = NewRemoteList(nil) i.HandshakeReady = true // Adding something to pending should not affect the main hostmap diff --git a/lighthouse.go b/lighthouse.go index 2532fc4..460a1cb 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "net/netip" "sync" "sync/atomic" "time" @@ -33,6 +34,7 @@ type netIpAndPort struct { type LightHouse struct { //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps + ctx context.Context amLighthouse bool myVpnIp iputil.VpnIp myVpnZeros iputil.VpnIp @@ -82,7 +84,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -100,6 +102,7 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, ones, _ := myVpnNet.Mask.Size() h := LightHouse{ + ctx: ctx, amLighthouse: amLighthouse, myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), myVpnZeros: iputil.VpnIp(32 - ones), @@ -258,7 +261,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config - if initial || c.HasChanged("static_host_map") { + if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") { staticList := make(map[iputil.VpnIp]struct{}) err := lh.loadStaticMap(c, lh.myVpnNet, staticList) if err != nil { @@ -268,9 +271,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.staticList.Store(&staticList) if !initial { //TODO: we should remove any remote list entries for static hosts that were removed/modified? - lh.l.Info("static_host_map has changed") + if c.HasChanged("static_host_map") { + lh.l.Info("static_host_map has changed") + } + if c.HasChanged("static_map.cadence") { + lh.l.Info("static_map.cadence has changed") + } + if c.HasChanged("static_map.network") { + lh.l.Info("static_map.network has changed") + } + if c.HasChanged("static_map.lookup_timeout") { + lh.l.Info("static_map.lookup_timeout has changed") + } } - } if initial || c.HasChanged("lighthouse.hosts") { @@ -344,7 +357,48 @@ func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap ma return nil } +func getStaticMapCadence(c *config.C) (time.Duration, error) { + cadence := c.GetString("static_map.cadence", "30s") + d, err := time.ParseDuration(cadence) + if err != nil { + return 0, err + } + return d, nil +} + +func getStaticMapLookupTimeout(c *config.C) (time.Duration, error) { + lookupTimeout := c.GetString("static_map.lookup_timeout", "250ms") + d, err := time.ParseDuration(lookupTimeout) + if err != nil { + return 0, err + } + return d, nil +} + +func getStaticMapNetwork(c *config.C) (string, error) { + network := c.GetString("static_map.network", "ip4") + if network != "ip" && network != "ip4" && network != "ip6" { + return "", fmt.Errorf("static_map.network must be one of ip, ip4, or ip6") + } + return network, nil +} + func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error { + d, err := getStaticMapCadence(c) + if err != nil { + return err + } + + network, err := getStaticMapNetwork(c) + if err != nil { + return err + } + + lookup_timeout, err := getStaticMapLookupTimeout(c) + if err != nil { + return err + } + shm := c.GetMap("static_host_map", map[interface{}]interface{}{}) i := 0 @@ -360,21 +414,17 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList vpnIp := iputil.Ip2VpnIp(rip) vals, ok := v.([]interface{}) - if ok { - for _, v := range vals { - ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) - if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) - } - lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList) - } + if !ok { + vals = []interface{}{v} + } + remoteAddrs := []string{} + for _, v := range vals { + remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) + } - } else { - ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) - if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) - } - lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList) + err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList) + if err != nil { + return err } i++ } @@ -482,30 +532,47 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error { lh.Lock() am := lh.unlockedGetRemoteList(vpnIp) am.Lock() defer am.Unlock() + ctx := lh.ctx lh.Unlock() - if ipv4 := toAddr.IP.To4(); ipv4 != nil { - to := NewIp4AndPort(ipv4, uint32(toAddr.Port)) - if !lh.unlockedShouldAddV4(vpnIp, to) { - return - } - am.unlockedPrependV4(lh.myVpnIp, to) + hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() { + // This callback runs whenever the DNS hostname resolver finds a different set of IP's + // in its resolution for hostnames. + am.Lock() + defer am.Unlock() + am.shouldRebuild = true + }) + if err != nil { + return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) + } + am.unlockedSetHostnamesResults(hr) - } else { - to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)) - if !lh.unlockedShouldAddV6(vpnIp, to) { - return + for _, addrPort := range hr.GetIPs() { + + switch { + case addrPort.Addr().Is4(): + to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) + if !lh.unlockedShouldAddV4(vpnIp, to) { + continue + } + am.unlockedPrependV4(lh.myVpnIp, to) + case addrPort.Addr().Is6(): + to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) + if !lh.unlockedShouldAddV6(vpnIp, to) { + continue + } + am.unlockedPrependV6(lh.myVpnIp, to) } - am.unlockedPrependV6(lh.myVpnIp, to) } // Mark it as static in the caller provided map staticList[vpnIp] = struct{}{} + return nil } // addCalculatedRemotes adds any calculated remotes based on the @@ -545,12 +612,42 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { am, ok := lh.addrMap[vpnIp] if !ok { - am = NewRemoteList() + am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) lh.addrMap[vpnIp] = am } return am } +func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool { + switch { + case to.Is4(): + ipBytes := to.As4() + ip := iputil.Ip2VpnIp(ipBytes[:]) + allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip) + if lh.l.Level >= logrus.TraceLevel { + lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + } + if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) { + return false + } + case to.Is6(): + ipBytes := to.As16() + + hi := binary.BigEndian.Uint64(ipBytes[:8]) + lo := binary.BigEndian.Uint64(ipBytes[8:]) + allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo) + if lh.l.Level >= logrus.TraceLevel { + lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow") + } + + // We don't check our vpn network here because nebula does not support ipv6 on the inside + if !allow { + return false + } + } + return true +} + // unlockedShouldAddV4 checks if to is allowed by our allow list func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) @@ -609,6 +706,14 @@ func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort { return &ipp } +func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { + v4Addr := ip.As4() + return &Ip4AndPort{ + Ip: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(port), + } +} + func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { return &Ip6AndPort{ Hi: binary.BigEndian.Uint64(ip[:8]), @@ -617,6 +722,14 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { } } +func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { + ip6Addr := ip.As16() + return &Ip6AndPort{ + Hi: binary.BigEndian.Uint64(ip6Addr[:8]), + Lo: binary.BigEndian.Uint64(ip6Addr[8:]), + Port: uint32(port), + } +} func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr { ip := ipp.Ip return udp.NewAddr( diff --git a/lighthouse_test.go b/lighthouse_test.go index 658c087..aa4da4c 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "fmt" "net" "testing" @@ -53,14 +54,14 @@ func Test_lhStaticMapping(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - _, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) + _, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) assert.Nil(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} - _, err = NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) + _, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } @@ -69,14 +70,14 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") c := config.NewC(l) - lh, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) if !assert.NoError(b, err) { b.Fatal() } hAddr := udp.NewAddrFromString("4.5.6.7:12345") hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") - lh.addrMap[3] = NewRemoteList() + lh.addrMap[3] = NewRemoteList(nil) lh.addrMap[3].unlockedSetV4( 3, 3, @@ -89,7 +90,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { rAddr := udp.NewAddrFromString("1.2.2.3:12345") rAddr2 := udp.NewAddrFromString("1.2.2.3:12346") - lh.addrMap[2] = NewRemoteList() + lh.addrMap[2] = NewRemoteList(nil) lh.addrMap[2].unlockedSetV4( 3, 3, @@ -162,7 +163,7 @@ func TestLighthouse_Memory(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) assert.NoError(t, err) lhh := lh.NewRequestHandler() @@ -238,7 +239,7 @@ func TestLighthouse_reload(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) assert.NoError(t, err) c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}} diff --git a/main.go b/main.go index bbf831a..4d604f5 100644 --- a/main.go +++ b/main.go @@ -226,7 +226,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg */ punchy := NewPunchyFromConfig(l, c) - lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy) + lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) switch { case errors.As(err, &util.ContextualError{}): return nil, err diff --git a/remote_list.go b/remote_list.go index 4b544f6..4540714 100644 --- a/remote_list.go +++ b/remote_list.go @@ -2,10 +2,16 @@ package nebula import ( "bytes" + "context" "net" + "net/netip" "sort" + "strconv" "sync" + "sync/atomic" + "time" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" ) @@ -55,6 +61,132 @@ type cacheV6 struct { reported []*Ip6AndPort } +type hostnamePort struct { + name string + port uint16 +} + +type hostnamesResults struct { + hostnames []hostnamePort + network string + lookupTimeout time.Duration + stop chan struct{} + l *logrus.Logger + ips atomic.Pointer[map[netip.AddrPort]struct{}] +} + +func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { + r := &hostnamesResults{ + hostnames: make([]hostnamePort, len(hostPorts)), + network: network, + lookupTimeout: timeout, + stop: make(chan (struct{})), + l: l, + } + + // Fastrack IP addresses to ensure they're immediately available for use. + // DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine. + performBackgroundLookup := false + ips := map[netip.AddrPort]struct{}{} + for idx, hostPort := range hostPorts { + + rIp, sPort, err := net.SplitHostPort(hostPort) + if err != nil { + return nil, err + } + + iPort, err := strconv.Atoi(sPort) + if err != nil { + return nil, err + } + + r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)} + addr, err := netip.ParseAddr(rIp) + if err != nil { + // This address is a hostname, not an IP address + performBackgroundLookup = true + continue + } + + // Save the IP address immediately + ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{} + } + r.ips.Store(&ips) + + // Time for the DNS lookup goroutine + if performBackgroundLookup { + ticker := time.NewTicker(d) + go func() { + defer ticker.Stop() + for { + netipAddrs := map[netip.AddrPort]struct{}{} + for _, hostPort := range r.hostnames { + timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout) + addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name) + timeoutCancel() + if err != nil { + l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host") + continue + } + for _, a := range addrs { + netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{} + } + } + origSet := r.ips.Load() + different := false + for a := range *origSet { + if _, ok := netipAddrs[a]; !ok { + different = true + break + } + } + if !different { + for a := range netipAddrs { + if _, ok := (*origSet)[a]; !ok { + different = true + break + } + } + } + if different { + l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list") + r.ips.Store(&netipAddrs) + onUpdate() + } + select { + case <-ctx.Done(): + return + case <-r.stop: + return + case <-ticker.C: + continue + } + } + }() + } + + return r, nil +} + +func (hr *hostnamesResults) Cancel() { + if hr != nil { + hr.stop <- struct{}{} + } +} + +func (hr *hostnamesResults) GetIPs() []netip.AddrPort { + var retSlice []netip.AddrPort + if hr != nil { + p := hr.ips.Load() + if p != nil { + for k := range *p { + retSlice = append(retSlice, k) + } + } + } + return retSlice +} + // RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos. // It serves as a local cache of query replies, host update notifications, and locally learned addresses type RemoteList struct { @@ -72,6 +204,9 @@ type RemoteList struct { // For learned addresses, this is the vpnIp that sent the packet cache map[iputil.VpnIp]*cache + hr *hostnamesResults + shouldAdd func(netip.Addr) bool + // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // They should not be tried again during a handshake badRemotes []*udp.Addr @@ -81,14 +216,21 @@ type RemoteList struct { } // NewRemoteList creates a new empty RemoteList -func NewRemoteList() *RemoteList { +func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { return &RemoteList{ - addrs: make([]*udp.Addr, 0), - relays: make([]*iputil.VpnIp, 0), - cache: make(map[iputil.VpnIp]*cache), + addrs: make([]*udp.Addr, 0), + relays: make([]*iputil.VpnIp, 0), + cache: make(map[iputil.VpnIp]*cache), + shouldAdd: shouldAdd, } } +func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { + // Cancel any existing hostnamesResults DNS goroutine to release resources + r.hr.Cancel() + r.hr = hr +} + // Len locks and reports the size of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { @@ -437,6 +579,26 @@ func (r *RemoteList) unlockedCollect() { } } + dnsAddrs := r.hr.GetIPs() + for _, addr := range dnsAddrs { + if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { + switch { + case addr.Addr().Is4(): + v4 := addr.Addr().As4() + addrs = append(addrs, &udp.Addr{ + IP: v4[:], + Port: addr.Port(), + }) + case addr.Addr().Is6(): + v6 := addr.Addr().As16() + addrs = append(addrs, &udp.Addr{ + IP: v6[:], + Port: addr.Port(), + }) + } + } + } + r.addrs = addrs r.relays = relays diff --git a/remote_list_test.go b/remote_list_test.go index 2170930..49aa171 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -9,7 +9,7 @@ import ( ) func TestRemoteList_Rebuild(t *testing.T) { - rl := NewRemoteList() + rl := NewRemoteList(nil) rl.unlockedSetV4( 0, 0, @@ -102,7 +102,7 @@ func TestRemoteList_Rebuild(t *testing.T) { } func BenchmarkFullRebuild(b *testing.B) { - rl := NewRemoteList() + rl := NewRemoteList(nil) rl.unlockedSetV4( 0, 0, @@ -167,7 +167,7 @@ func BenchmarkFullRebuild(b *testing.B) { } func BenchmarkSortRebuild(b *testing.B) { - rl := NewRemoteList() + rl := NewRemoteList(nil) rl.unlockedSetV4( 0, 0,