diff --git a/connection_manager.go b/connection_manager.go index e76cd95..0ea1f75 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -5,49 +5,55 @@ import ( "sync" "time" + "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" ) -// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet -// and something like every 10 packets we could lock, send 10, then unlock for a moment - type connectionManager struct { - hostMap *HostMap - in map[uint32]struct{} - inLock *sync.RWMutex - out map[uint32]struct{} - outLock *sync.RWMutex - TrafficTimer *LockingTimerWheel[uint32] - intf *Interface + in map[uint32]struct{} + inLock *sync.RWMutex - pendingDeletion map[uint32]int - pendingDeletionLock *sync.RWMutex - pendingDeletionTimer *LockingTimerWheel[uint32] + out map[uint32]struct{} + outLock *sync.RWMutex - checkInterval int - pendingDeletionInterval int + hostMap *HostMap + trafficTimer *LockingTimerWheel[uint32] + intf *Interface + pendingDeletion map[uint32]struct{} + punchy *Punchy + checkInterval time.Duration + pendingDeletionInterval time.Duration + metricsTxPunchy metrics.Counter l *logrus.Logger - // I wanted to call one matLock } -func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager { +func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager { + var max time.Duration + if checkInterval < pendingDeletionInterval { + max = pendingDeletionInterval + } else { + max = checkInterval + } + nc := &connectionManager{ hostMap: intf.hostMap, in: make(map[uint32]struct{}), inLock: &sync.RWMutex{}, out: make(map[uint32]struct{}), outLock: &sync.RWMutex{}, - TrafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60), + trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max), intf: intf, - pendingDeletion: make(map[uint32]int), - pendingDeletionLock: &sync.RWMutex{}, - pendingDeletionTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60), + pendingDeletion: make(map[uint32]struct{}), checkInterval: checkInterval, pendingDeletionInterval: pendingDeletionInterval, + punchy: punchy, + metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), l: l, } + nc.Start(ctx) return nc } @@ -74,65 +80,27 @@ func (n *connectionManager) Out(localIndex uint32) { } n.outLock.RUnlock() n.outLock.Lock() - // double check since we dropped the lock temporarily - if _, ok := n.out[localIndex]; ok { - n.outLock.Unlock() - return - } n.out[localIndex] = struct{}{} - n.AddTrafficWatch(localIndex, n.checkInterval) n.outLock.Unlock() } -func (n *connectionManager) CheckIn(localIndex uint32) bool { - n.inLock.RLock() - if _, ok := n.in[localIndex]; ok { - n.inLock.RUnlock() - return true - } - n.inLock.RUnlock() - return false -} - -func (n *connectionManager) ClearLocalIndex(localIndex uint32) { +// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and +// resets the state for this local index +func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) { n.inLock.Lock() n.outLock.Lock() + _, in := n.in[localIndex] + _, out := n.out[localIndex] delete(n.in, localIndex) delete(n.out, localIndex) n.inLock.Unlock() n.outLock.Unlock() + return in, out } -func (n *connectionManager) ClearPendingDeletion(localIndex uint32) { - n.pendingDeletionLock.Lock() - delete(n.pendingDeletion, localIndex) - n.pendingDeletionLock.Unlock() -} - -func (n *connectionManager) AddPendingDeletion(localIndex uint32) { - n.pendingDeletionLock.Lock() - if _, ok := n.pendingDeletion[localIndex]; ok { - n.pendingDeletion[localIndex] += 1 - } else { - n.pendingDeletion[localIndex] = 0 - } - n.pendingDeletionTimer.Add(localIndex, time.Second*time.Duration(n.pendingDeletionInterval)) - n.pendingDeletionLock.Unlock() -} - -func (n *connectionManager) checkPendingDeletion(localIndex uint32) bool { - n.pendingDeletionLock.RLock() - if _, ok := n.pendingDeletion[localIndex]; ok { - - n.pendingDeletionLock.RUnlock() - return true - } - n.pendingDeletionLock.RUnlock() - return false -} - -func (n *connectionManager) AddTrafficWatch(localIndex uint32, seconds int) { - n.TrafficTimer.Add(localIndex, time.Second*time.Duration(seconds)) +func (n *connectionManager) AddTrafficWatch(localIndex uint32) { + n.Out(localIndex) + n.trafficTimer.Add(localIndex, n.checkInterval) } func (n *connectionManager) Start(ctx context.Context) { @@ -140,6 +108,7 @@ func (n *connectionManager) Start(ctx context.Context) { } func (n *connectionManager) Run(ctx context.Context) { + //TODO: this tick should be based on the min wheel tick? Check firewall clockSource := time.NewTicker(500 * time.Millisecond) defer clockSource.Stop() @@ -151,151 +120,106 @@ func (n *connectionManager) Run(ctx context.Context) { select { case <-ctx.Done(): return + case now := <-clockSource.C: - n.HandleMonitorTick(now, p, nb, out) - n.HandleDeletionTick(now) - } - } -} - -func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) { - n.TrafficTimer.Advance(now) - for { - localIndex, has := n.TrafficTimer.Purge() - if !has { - break - } - - // Check for traffic coming back in from this host. - traf := n.CheckIn(localIndex) - - hostinfo, err := n.hostMap.QueryIndex(localIndex) - if err != nil { - n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") - n.ClearLocalIndex(localIndex) - n.ClearPendingDeletion(localIndex) - continue - } - - if n.handleInvalidCertificate(now, hostinfo) { - continue - } - - // Does the vpnIp point to this hostinfo or is it ancillary? If we have ancillary hostinfos then we need to - // decide if this should be the main hostinfo if we are seeing traffic on it - primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp) - mainHostInfo := true - if primary != nil && primary != hostinfo { - mainHostInfo = false - } - - // If we saw an incoming packets from this ip and peer's certificate is not - // expired, just ignore. - if traf { - if n.l.Level >= logrus.DebugLevel { - hostinfo.logger(n.l). - WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). - Debug("Tunnel status") - } - n.ClearLocalIndex(localIndex) - n.ClearPendingDeletion(localIndex) - - if !mainHostInfo { - if hostinfo.vpnIp > n.intf.myVpnIp { - // We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make - // This the primary and prime the old primary hostinfo for testing - n.hostMap.MakePrimary(hostinfo) - n.Out(primary.localIndexId) - } else { - // This hostinfo is still being used despite not being the primary hostinfo for this vpn ip - // Keep tracking so that we can tear it down when it goes away - n.Out(hostinfo.localIndexId) + n.trafficTimer.Advance(now) + for { + localIndex, has := n.trafficTimer.Purge() + if !has { + break } + + n.doTrafficCheck(localIndex, p, nb, out, now) } - - continue } - - if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) { - // Don't probe lighthouses since recv_error should naturally catch this. - n.ClearLocalIndex(localIndex) - n.ClearPendingDeletion(localIndex) - continue - } - - hostinfo.logger(n.l). - WithField("tunnelCheck", m{"state": "testing", "method": "active"}). - Debug("Tunnel status") - - if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { - // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out) - - } else { - hostinfo.logger(n.l).Debugf("Hostinfo sadness") - } - n.AddPendingDeletion(localIndex) } - } -func (n *connectionManager) HandleDeletionTick(now time.Time) { - n.pendingDeletionTimer.Advance(now) - for { - localIndex, has := n.pendingDeletionTimer.Purge() - if !has { - break - } - - hostinfo, mainHostInfo, err := n.hostMap.QueryIndexIsPrimary(localIndex) - if err != nil { - n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") - n.ClearLocalIndex(localIndex) - n.ClearPendingDeletion(localIndex) - continue - } - - if n.handleInvalidCertificate(now, hostinfo) { - continue - } - - // If we saw an incoming packets from this ip and peer's certificate is not - // expired, just ignore. - traf := n.CheckIn(localIndex) - if traf { - hostinfo.logger(n.l). - WithField("tunnelCheck", m{"state": "alive", "method": "active"}). - Debug("Tunnel status") - - n.ClearLocalIndex(localIndex) - n.ClearPendingDeletion(localIndex) - - if !mainHostInfo { - // This hostinfo is still being used despite not being the primary hostinfo for this vpn ip - // Keep tracking so that we can tear it down when it goes away - n.Out(localIndex) - } - continue - } - - // If it comes around on deletion wheel and hasn't resolved itself, delete - if n.checkPendingDeletion(localIndex) { - cn := "" - if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil { - cn = hostinfo.ConnectionState.peerCert.Details.Name - } - - hostinfo.logger(n.l). - WithField("tunnelCheck", m{"state": "dead", "method": "active"}). - WithField("certName", cn). - Info("Tunnel status") - - n.hostMap.DeleteHostInfo(hostinfo) - } - - n.ClearLocalIndex(localIndex) - n.ClearPendingDeletion(localIndex) +func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { + hostinfo, err := n.hostMap.QueryIndex(localIndex) + if err != nil { + n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") + delete(n.pendingDeletion, localIndex) + return } + + if n.handleInvalidCertificate(now, hostinfo) { + return + } + + primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp) + mainHostInfo := true + if primary != nil && primary != hostinfo { + mainHostInfo = false + } + + // Check for traffic on this hostinfo + inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex) + + // A hostinfo is determined alive if there is incoming traffic + if inTraffic { + if n.l.Level >= logrus.DebugLevel { + hostinfo.logger(n.l). + WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). + Debug("Tunnel status") + } + delete(n.pendingDeletion, hostinfo.localIndexId) + + if !mainHostInfo { + if hostinfo.vpnIp > n.intf.myVpnIp { + // We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make + // This the primary and prime the old primary hostinfo for testing + n.hostMap.MakePrimary(hostinfo) + } + } + + n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + + if !outTraffic { + // Send a punch packet to keep the NAT state alive + n.sendPunch(hostinfo) + } + + return + } + + if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) { + // We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel + n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + return + } + + if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok { + // We have already sent a test packet and nothing was returned, this hostinfo is dead + hostinfo.logger(n.l). + WithField("tunnelCheck", m{"state": "dead", "method": "active"}). + Info("Tunnel status") + + n.hostMap.DeleteHostInfo(hostinfo) + delete(n.pendingDeletion, hostinfo.localIndexId) + return + } + + hostinfo.logger(n.l). + WithField("tunnelCheck", m{"state": "testing", "method": "active"}). + Debug("Tunnel status") + + if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { + if n.punchy.GetTargetEverything() { + // Maybe the remote is sending us packets but our NAT is blocking it and since we are configured to punch to all + // known remotes, go ahead and do that AND send a test packet + n.sendPunch(hostinfo) + } + + // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues + n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out) + + } else { + hostinfo.logger(n.l).Debugf("Hostinfo sadness") + } + + n.pendingDeletion[hostinfo.localIndexId] = struct{}{} + n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval) } // handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid @@ -322,8 +246,24 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *Ho // Inform the remote and close the tunnel locally n.intf.sendCloseTunnel(hostinfo) n.intf.closeTunnel(hostinfo) - - n.ClearLocalIndex(hostinfo.localIndexId) - n.ClearPendingDeletion(hostinfo.localIndexId) + delete(n.pendingDeletion, hostinfo.localIndexId) return true } + +func (n *connectionManager) sendPunch(hostinfo *HostInfo) { + if !n.punchy.GetPunch() { + // Punching is disabled + return + } + + if n.punchy.GetTargetEverything() { + hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) { + n.metricsTxPunchy.Inc(1) + n.intf.outside.WriteTo([]byte{1}, addr) + }) + + } else if hostinfo.remote != nil { + n.metricsTxPunchy.Inc(1) + n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) + } +} diff --git a/connection_manager_test.go b/connection_manager_test.go index b02c1bf..e05376d 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -10,6 +10,7 @@ import ( "github.com/flynn/noise" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" @@ -60,16 +61,16 @@ func Test_NewConnectionManagerTest(t *testing.T) { l: l, } ifce.certState.Store(cs) - now := time.Now() // Create manager ctx, cancel := context.WithCancel(context.Background()) defer cancel() - nc := newConnectionManager(ctx, l, ifce, 5, 10) + punchy := NewPunchyFromConfig(l, config.NewC(l)) + nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) - nc.HandleMonitorTick(now, p, nb, out) + // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ vpnIp: vpnIp, @@ -84,26 +85,28 @@ func Test_NewConnectionManagerTest(t *testing.T) { // We saw traffic out to vpnIp nc.Out(hostinfo.localIndexId) + nc.In(hostinfo.localIndexId) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - // Move ahead 5s. Nothing should happen - next_tick := now.Add(5 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // Move ahead 6s. We haven't heard back - next_tick = now.Add(6 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // This host should now be up for deletion + assert.Contains(t, nc.out, hostinfo.localIndexId) + + // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) + assert.NotContains(t, nc.out, hostinfo.localIndexId) + assert.NotContains(t, nc.in, hostinfo.localIndexId) + + // Do another traffic check tick, this host should be pending deletion now + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.out, hostinfo.localIndexId) + assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - // Move ahead some more - next_tick = now.Add(45 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // The host should be evicted + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + + // Do a final traffic check tick, the host should now be removed + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) @@ -136,16 +139,16 @@ func Test_NewConnectionManagerTest2(t *testing.T) { l: l, } ifce.certState.Store(cs) - now := time.Now() // Create manager ctx, cancel := context.WithCancel(context.Background()) defer cancel() - nc := newConnectionManager(ctx, l, ifce, 5, 10) + punchy := NewPunchyFromConfig(l, config.NewC(l)) + nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) - nc.HandleMonitorTick(now, p, nb, out) + // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ vpnIp: vpnIp, @@ -160,30 +163,33 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // We saw traffic out to vpnIp nc.Out(hostinfo.localIndexId) - assert.NotContains(t, nc.pendingDeletion, vpnIp) - assert.Contains(t, nc.hostMap.Hosts, vpnIp) - // Move ahead 5s. Nothing should happen - next_tick := now.Add(5 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // Move ahead 6s. We haven't heard back - next_tick = now.Add(6 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // This host should now be up for deletion - assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, vpnIp) - assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - // We heard back this time nc.In(hostinfo.localIndexId) - // Move ahead some more - next_tick = now.Add(45 * time.Second) - nc.HandleMonitorTick(next_tick, p, nb, out) - nc.HandleDeletionTick(next_tick) - // The host should not be evicted - assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) + assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) + + // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) + assert.NotContains(t, nc.out, hostinfo.localIndexId) + assert.NotContains(t, nc.in, hostinfo.localIndexId) + + // Do another traffic check tick, this host should be pending deletion now + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) + assert.NotContains(t, nc.out, hostinfo.localIndexId) + assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + + // We saw traffic, should no longer be pending deletion + nc.In(hostinfo.localIndexId) + nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) + assert.NotContains(t, nc.out, hostinfo.localIndexId) + assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) } // Check if we can disconnect the peer. @@ -257,7 +263,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { // Create manager ctx, cancel := context.WithCancel(context.Background()) defer cancel() - nc := newConnectionManager(ctx, l, ifce, 5, 10) + punchy := NewPunchyFromConfig(l, config.NewC(l)) + nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) ifce.connectionManager = nc hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil) hostinfo.ConnectionState = &ConnectionState{ diff --git a/handshake_ix.go b/handshake_ix.go index bb511cc..a51fb31 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -332,12 +332,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b Info("Handshake message sent") } - if existing != nil { - // Make sure we are tracking the old primary if there was one, it needs to go away eventually - f.connectionManager.Out(existing.localIndexId) - } - - f.connectionManager.Out(hostinfo.localIndexId) + f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) return @@ -495,12 +490,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo * hostinfo.CreateRemoteCIDR(remoteCert) // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp - existing := f.handshakeManager.Complete(hostinfo, f) - if existing != nil { - // Make sure we are tracking the old primary if there was one, it needs to go away eventually - f.connectionManager.Out(existing.localIndexId) - } - + f.handshakeManager.Complete(hostinfo, f) + f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) f.metricHandshakes.Update(duration) diff --git a/handshake_manager.go b/handshake_manager.go index e93cbee..449a4da 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -380,7 +380,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket // Complete is a simpler version of CheckAndComplete when we already know we // won't have a localIndexId collision because we already have an entry in the // pendingHostMap. An existing hostinfo is returned if there was one. -func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) *HostInfo { +func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { c.pendingHostMap.Lock() defer c.pendingHostMap.Unlock() c.mainHostMap.Lock() @@ -395,11 +395,9 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) *HostInfo Info("New host shadows existing host remoteIndex") } - existingHostInfo := c.mainHostMap.Hosts[hostinfo.vpnIp] // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap. c.pendingHostMap.unlockedDeleteHostInfo(hostinfo) c.mainHostMap.unlockedAddHostInfo(hostinfo, f) - return existingHostInfo } // AddIndexHostInfo generates a unique localIndexId for this HostInfo diff --git a/hostmap.go b/hostmap.go index a27a7f9..ebfb840 100644 --- a/hostmap.go +++ b/hostmap.go @@ -1,7 +1,6 @@ package nebula import ( - "context" "errors" "fmt" "net" @@ -621,54 +620,6 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { } } -// punchList assembles a list of all non nil RemoteList pointer entries in this hostmap -// The caller can then do the its work outside of the read lock -func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList { - hm.RLock() - defer hm.RUnlock() - - for _, v := range hm.Hosts { - if v.remotes != nil { - rl = append(rl, v.remotes) - } - } - return rl -} - -// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them -func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) { - var metricsTxPunchy metrics.Counter - if hm.metricsEnabled { - metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil) - } else { - metricsTxPunchy = metrics.NilCounter{} - } - - var remotes []*RemoteList - b := []byte{1} - - clockSource := time.NewTicker(time.Second * 10) - defer clockSource.Stop() - - for { - remotes = hm.punchList(remotes[:0]) - for _, rl := range remotes { - //TODO: CopyAddrs generates garbage but ForEach locks for the work here, figure out which way is better - for _, addr := range rl.CopyAddrs(hm.preferredRanges) { - metricsTxPunchy.Inc(1) - conn.WriteTo(b, addr) - } - } - - select { - case <-ctx.Done(): - return - case <-clockSource.C: - continue - } - } -} - // TryPromoteBest handles re-querying lighthouses and probing for better paths // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { diff --git a/interface.go b/interface.go index cc6e781..b4822ed 100644 --- a/interface.go +++ b/interface.go @@ -33,8 +33,8 @@ type InterfaceConfig struct { ServeDns bool HandshakeManager *HandshakeManager lightHouse *LightHouse - checkInterval int - pendingDeletionInterval int + checkInterval time.Duration + pendingDeletionInterval time.Duration DropLocalBroadcast bool DropMulticast bool routines int @@ -43,6 +43,7 @@ type InterfaceConfig struct { caPool *cert.NebulaCAPool disconnectInvalid bool relayManager *relayManager + punchy *Punchy ConntrackCacheTimeout time.Duration l *logrus.Logger @@ -172,7 +173,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { } ifce.certState.Store(c.certState) - ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval) + ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy) return ifce, nil } diff --git a/main.go b/main.go index 99fe72c..f9ea77c 100644 --- a/main.go +++ b/main.go @@ -213,11 +213,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg */ punchy := NewPunchyFromConfig(l, c) - if punchy.GetPunch() && !configTest { - l.Info("UDP hole punching enabled") - go hostMap.Punchy(ctx, udpConns[0]) - } - lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy) switch { case errors.As(err, &util.ContextualError{}): @@ -272,8 +267,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg ServeDns: serveDns, HandshakeManager: handshakeManager, lightHouse: lightHouse, - checkInterval: checkInterval, - pendingDeletionInterval: pendingDeletionInterval, + checkInterval: time.Second * time.Duration(checkInterval), + pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval), DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false), DropMulticast: c.GetBool("tun.drop_multicast", false), routines: routines, @@ -282,6 +277,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg caPool: caPool, disconnectInvalid: c.GetBool("pki.disconnect_invalid", false), relayManager: NewRelayManager(ctx, l, hostMap, c), + punchy: punchy, ConntrackCacheTimeout: conntrackCacheTimeout, l: l, diff --git a/outside.go b/outside.go index 9d51786..fd6f0a3 100644 --- a/outside.go +++ b/outside.go @@ -238,9 +238,6 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote func (f *Interface) closeTunnel(hostInfo *HostInfo) { - //TODO: this would be better as a single function in ConnectionManager that handled locks appropriately - f.connectionManager.ClearLocalIndex(hostInfo.localIndexId) - f.connectionManager.ClearPendingDeletion(hostInfo.localIndexId) final := f.hostMap.DeleteHostInfo(hostInfo) if final { // We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage diff --git a/punchy.go b/punchy.go index a930ac5..c0bdbd3 100644 --- a/punchy.go +++ b/punchy.go @@ -9,11 +9,12 @@ import ( ) type Punchy struct { - punch atomic.Bool - respond atomic.Bool - delay atomic.Int64 - respondDelay atomic.Int64 - l *logrus.Logger + punch atomic.Bool + respond atomic.Bool + delay atomic.Int64 + respondDelay atomic.Int64 + punchEverything atomic.Bool + l *logrus.Logger } func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy { @@ -38,6 +39,12 @@ func (p *Punchy) reload(c *config.C, initial bool) { } p.punch.Store(yes) + if yes { + p.l.Info("punchy enabled") + } else { + p.l.Info("punchy disabled") + } + } else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") { //TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.") @@ -66,6 +73,14 @@ func (p *Punchy) reload(c *config.C, initial bool) { p.l.Infof("punchy.delay changed to %s", p.GetDelay()) } } + + if initial || c.HasChanged("punchy.target_all_remotes") { + p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", true)) + if !initial { + p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed") + } + } + if initial || c.HasChanged("punchy.respond_delay") { p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) if !initial { @@ -89,3 +104,7 @@ func (p *Punchy) GetDelay() time.Duration { func (p *Punchy) GetRespondDelay() time.Duration { return (time.Duration)(p.respondDelay.Load()) } + +func (p *Punchy) GetTargetEverything() bool { + return p.punchEverything.Load() +}