diff --git a/connection_state.go b/connection_state.go index f8c31f6..8ef8b3a 100644 --- a/connection_state.go +++ b/connection_state.go @@ -24,7 +24,6 @@ type ConnectionState struct { messageCounter atomic.Uint64 window *Bits writeLock sync.Mutex - ready bool } func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { @@ -71,7 +70,6 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i H: hs, initiator: initiator, window: b, - ready: false, myCert: certState.Certificate, } @@ -83,6 +81,5 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) { "certificate": cs.peerCert, "initiator": cs.initiator, "message_counter": cs.messageCounter.Load(), - "ready": cs.ready, }) } diff --git a/control.go b/control.go index 4af115c..13b2658 100644 --- a/control.go +++ b/control.go @@ -41,7 +41,6 @@ type ControlHostInfo struct { LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` RemoteAddrs []*udp.Addr `json:"remoteAddrs"` - CachedPackets int `json:"cachedPackets"` Cert *cert.NebulaCertificate `json:"cert"` MessageCounter uint64 `json:"messageCounter"` CurrentRemote *udp.Addr `json:"currentRemote"` @@ -234,7 +233,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), - CachedPackets: len(h.packetStore), CurrentRelaysToMe: h.relayState.CopyRelayIps(), CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(), } diff --git a/control_test.go b/control_test.go index 56a2b2f..847332b 100644 --- a/control_test.go +++ b/control_test.go @@ -96,7 +96,6 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { LocalIndex: 201, RemoteIndex: 200, RemoteAddrs: []*udp.Addr{remote2, remote1}, - CachedPackets: 0, Cert: crt.Copy(), MessageCounter: 0, CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444), @@ -105,7 +104,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } // Make sure we don't have any unexpected fields - assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) + assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet diff --git a/handshake.go b/handshake.go deleted file mode 100644 index 8cfba21..0000000 --- a/handshake.go +++ /dev/null @@ -1,31 +0,0 @@ -package nebula - -import ( - "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/udp" -) - -func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) { - // First remote allow list check before we know the vpnIp - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { - f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") - return - } - } - - switch h.Subtype { - case header.HandshakeIXPSK0: - switch h.MessageCounter { - case 1: - ixHandshakeStage1(f, addr, via, packet, h) - case 2: - newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex) - tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h) - if tearDown && newHostinfo != nil { - f.handshakeManager.DeleteHostInfo(newHostinfo) - } - } - } - -} diff --git a/handshake_ix.go b/handshake_ix.go index 7e60c79..1905c00 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -4,6 +4,7 @@ import ( "time" "github.com/flynn/noise" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" @@ -13,20 +14,20 @@ import ( // This function constructs a handshake packet, but does not actually send it // Sending is done by the handshake manager -func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool { - err := f.handshakeManager.allocateIndex(hostinfo) +func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { + err := f.handshakeManager.allocateIndex(hh) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return false } certState := f.pki.GetCertState() ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) - hostinfo.ConnectionState = ci + hh.hostinfo.ConnectionState = ci hsProto := &NebulaHandshakeDetails{ - InitiatorIndex: hostinfo.localIndexId, + InitiatorIndex: hh.hostinfo.localIndexId, Time: uint64(time.Now().UnixNano()), Cert: certState.RawCertificateNoKey, } @@ -39,7 +40,7 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool { hsBytes, err = hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return false } @@ -49,7 +50,7 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return false } @@ -58,9 +59,8 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool { // handshake packet 1 from the responder ci.window.Update(f.l, 1) - hostinfo.HandshakePacket[0] = msg - hostinfo.HandshakeReady = true - hostinfo.handshakeStart = time.Now() + hh.hostinfo.HandshakePacket[0] = msg + hh.ready = true return true } @@ -140,9 +140,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by }, } - hostinfo.Lock() - defer hostinfo.Unlock() - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -208,19 +205,12 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by if err != nil { switch err { case ErrAlreadySeen: - // Update remote if preferred (Note we have to switch to locking - // the existing hostinfo, and then switch back so the defer Unlock - // higher in this function still works) - hostinfo.Unlock() - existing.Lock() // Update remote if preferred if existing.SetRemoteIfPreferred(f.hostMap, addr) { // Send a test packet to ensure the other side has also switched to // the preferred remote f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) } - existing.Unlock() - hostinfo.Lock() msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) @@ -307,7 +297,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("sentCachedPackets", len(hostinfo.packetStore)). Info("Handshake message sent") } } else { @@ -323,25 +312,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithField("sentCachedPackets", len(hostinfo.packetStore)). Info("Handshake message sent") } f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) + hostinfo.ConnectionState.messageCounter.Store(2) + hostinfo.remotes.ResetBlockedRemotes() return } -func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool { - if hostinfo == nil { +func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { + if hh == nil { // Nothing here to tear down, got a bogus stage 2 packet return true } - hostinfo.Lock() - defer hostinfo.Unlock() + hh.Lock() + defer hh.Unlock() + hostinfo := hh.hostinfo if addr != nil { if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") @@ -350,22 +340,6 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H } ci := hostinfo.ConnectionState - if ci.ready { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). - Info("Handshake is already complete") - - // Update remote if preferred - if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) { - // Send a test packet to ensure the other side has also switched to - // the preferred remote - f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - } - - // We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets - return false - } - msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). @@ -422,22 +396,22 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip - f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHostInfo *HostInfo) { + f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) { //TODO: this doesnt know if its being added or is being used for caching a packet // Block the current used address - newHostInfo.remotes = hostinfo.remotes - newHostInfo.remotes.BlockRemote(addr) + newHH.hostinfo.remotes = hostinfo.remotes + newHH.hostinfo.remotes.BlockRemote(addr) // Get the correct remote list for the host we did handshake with hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) - f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). - WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). + f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). + WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). Info("Blocked addresses for handshakes") // Swap the packet store to benefit the original intended recipient - newHostInfo.packetStore = hostinfo.packetStore - hostinfo.packetStore = []*cachedPacket{} + newHH.packetStore = hh.packetStore + hh.packetStore = []*cachedPacket{} // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down hostinfo.vpnIp = vpnIp @@ -450,7 +424,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H // Mark packet 2 as seen so it doesn't show up as missed ci.window.Update(f.l, 2) - duration := time.Since(hostinfo.handshakeStart).Nanoseconds() + duration := time.Since(hh.startTime).Nanoseconds() f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -458,7 +432,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("durationNs", duration). - WithField("sentCachedPackets", len(hostinfo.packetStore)). + WithField("sentCachedPackets", len(hh.packetStore)). Info("Handshake message received") hostinfo.remoteIndexId = hs.Details.ResponderIndex @@ -482,7 +456,23 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics) + + hostinfo.ConnectionState.messageCounter.Store(2) + + if f.l.Level >= logrus.DebugLevel { + hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) + } + + if len(hh.packetStore) > 0 { + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + for _, cp := range hh.packetStore { + cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) + } + f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) + } + + hostinfo.remotes.ResetBlockedRemotes() f.metricHandshakes.Update(duration) return false diff --git a/handshake_manager.go b/handshake_manager.go index 11c0c6f..00321d6 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -46,8 +46,8 @@ type HandshakeManager struct { // Mutex for interacting with the vpnIps and indexes maps sync.RWMutex - vpnIps map[iputil.VpnIp]*HostInfo - indexes map[uint32]*HostInfo + vpnIps map[iputil.VpnIp]*HandshakeHostInfo + indexes map[uint32]*HandshakeHostInfo mainHostMap *HostMap lightHouse *LightHouse @@ -64,10 +64,47 @@ type HandshakeManager struct { trigger chan iputil.VpnIp } +type HandshakeHostInfo struct { + sync.Mutex + + startTime time.Time // Time that we first started trying with this handshake + ready bool // Is the handshake ready + counter int // How many attempts have we made so far + lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt + packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes + + hostinfo *HostInfo +} + +func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { + if len(hh.packetStore) < 100 { + tempPacket := make([]byte, len(packet)) + copy(tempPacket, packet) + + hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket}) + if l.Level >= logrus.DebugLevel { + hh.hostinfo.logger(l). + WithField("length", len(hh.packetStore)). + WithField("stored", true). + Debugf("Packet store") + } + + } else { + m.dropped.Inc(1) + + if l.Level >= logrus.DebugLevel { + hh.hostinfo.logger(l). + WithField("length", len(hh.packetStore)). + WithField("stored", false). + Debugf("Packet store") + } + } +} + func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ - vpnIps: map[iputil.VpnIp]*HostInfo{}, - indexes: map[uint32]*HostInfo{}, + vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{}, + indexes: map[uint32]*HandshakeHostInfo{}, mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, @@ -97,6 +134,31 @@ func (c *HandshakeManager) Run(ctx context.Context) { } } +func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { + // First remote allow list check before we know the vpnIp + if addr != nil { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { + hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + return + } + } + + switch h.Subtype { + case header.HandshakeIXPSK0: + switch h.MessageCounter { + case 1: + ixHandshakeStage1(hm.f, addr, via, packet, h) + + case 2: + newHostinfo := hm.queryIndex(h.RemoteIndex) + tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h) + if tearDown && newHostinfo != nil { + hm.DeleteHostInfo(newHostinfo.hostinfo) + } + } + } +} + func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { c.OutboundHandshakeTimer.Advance(now) for { @@ -108,41 +170,35 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { } } -func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { - hostinfo := c.QueryVpnIp(vpnIp) - if hostinfo == nil { - return - } - hostinfo.Lock() - defer hostinfo.Unlock() - - // We may have raced to completion but now that we have a lock we should ensure we have not yet completed. - if hostinfo.HandshakeComplete { - // Ensure we don't exist in the pending hostmap anymore since we have completed - c.DeleteHostInfo(hostinfo) +func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { + hh := hm.queryVpnIp(vpnIp) + if hh == nil { return } + hh.Lock() + defer hh.Unlock() + hostinfo := hh.hostinfo // If we are out of time, clean up - if hostinfo.HandshakeCounter >= c.config.retries { - hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)). - WithField("initiatorIndex", hostinfo.localIndexId). - WithField("remoteIndex", hostinfo.remoteIndexId). + if hh.counter >= hm.config.retries { + hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)). + WithField("initiatorIndex", hh.hostinfo.localIndexId). + WithField("remoteIndex", hh.hostinfo.remoteIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()). + WithField("durationNs", time.Since(hh.startTime).Nanoseconds()). Info("Handshake timed out") - c.metricTimedOut.Inc(1) - c.DeleteHostInfo(hostinfo) + hm.metricTimedOut.Inc(1) + hm.DeleteHostInfo(hostinfo) return } // Increment the counter to increase our delay, linear backoff - hostinfo.HandshakeCounter++ + hh.counter++ // Check if we have a handshake packet to transmit yet - if !hostinfo.HandshakeReady { - if !ixHandshakeStage0(c.f, hostinfo) { - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) + if !hh.ready { + if !ixHandshakeStage0(hm.f, hh) { + hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) return } } @@ -152,11 +208,11 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere // NB ^ This comment doesn't jive. It's how the thing gets initialized. // It's the common path. Should it update every time, in case a future LH query/queries give us more info? if hostinfo.remotes == nil { - hostinfo.remotes = c.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) } - remotes := hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges) - remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes) + remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges) + remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes) // We only care about a lighthouse trigger if we have new remotes to send to. // This is a very specific optimization for a fast lighthouse reply. @@ -165,25 +221,25 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere return } - hostinfo.HandshakeLastRemotes = remotes + hh.lastRemotes = remotes // TODO: this will generate a load of queries for hosts with only 1 ip // (such as ones registered to the lighthouse with only a private IP) // So we only do it one time after attempting 5 handshakes already. - if len(remotes) <= 1 && hostinfo.HandshakeCounter == 5 { + if len(remotes) <= 1 && hh.counter == 5 { // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about // the learned public ip for them. Query again to short circuit the promotion counter - c.lightHouse.QueryServer(vpnIp, c.f) + hm.lightHouse.QueryServer(vpnIp, hm.f) } // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply var sentTo []*udp.Addr - hostinfo.remotes.ForEach(c.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { - c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err := c.outside.WriteTo(hostinfo.HandshakePacket[0], addr) + hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { + hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) + err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { - hostinfo.logger(c.l).WithField("udpAddr", addr). + hostinfo.logger(hm.l).WithField("udpAddr", addr). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake message") @@ -196,63 +252,63 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout, // so only log when the list of remotes has changed if remotesHaveChanged { - hostinfo.logger(c.l).WithField("udpAddrs", sentTo). + hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("Handshake message sent") - } else if c.l.IsLevelEnabled(logrus.DebugLevel) { - hostinfo.logger(c.l).WithField("udpAddrs", sentTo). + } else if hm.l.IsLevelEnabled(logrus.DebugLevel) { + hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Debug("Handshake message sent") } - if c.config.useRelays && len(hostinfo.remotes.relays) > 0 { - hostinfo.logger(c.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") + if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { + hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay to myself, and don't relay through the host I'm trying to connect to - if *relay == vpnIp || *relay == c.lightHouse.myVpnIp { + if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp { continue } - relayHostInfo := c.mainHostMap.QueryVpnIp(*relay) + relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay) if relayHostInfo == nil || relayHostInfo.remote == nil { - hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") - c.f.Handshake(*relay) + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") + hm.f.Handshake(*relay) continue } // Check the relay HostInfo to see if we already established a relay through it if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok { switch existingRelay.State { case Established: - hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Send handshake via relay") - c.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") + hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Requested: - hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: uint32(c.lightHouse.myVpnIp), + RelayFromIp: uint32(hm.lightHouse.myVpnIp), RelayToIp: uint32(vpnIp), } msg, err := m.Marshal() if err != nil { - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithError(err). Error("Failed to marshal Control message to create relay") } else { // This must send over the hostinfo, not over hm.Hosts[ip] - c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - c.l.WithFields(logrus.Fields{ - "relayFrom": c.lightHouse.myVpnIp, + hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + hm.l.WithFields(logrus.Fields{ + "relayFrom": hm.lightHouse.myVpnIp, "relayTo": vpnIp, "initiatorRelayIndex": existingRelay.LocalIndex, "relay": *relay}). Info("send CreateRelayRequest") } default: - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithField("vpnIp", vpnIp). WithField("state", existingRelay.State). WithField("relay", relayHostInfo.vpnIp). @@ -261,26 +317,26 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere } else { // No relays exist or requested yet. if relayHostInfo.remote != nil { - idx, err := AddRelay(c.l, relayHostInfo, c.mainHostMap, vpnIp, nil, TerminalType, Requested) + idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { - hostinfo.logger(c.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") + hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: uint32(c.lightHouse.myVpnIp), + RelayFromIp: uint32(hm.lightHouse.myVpnIp), RelayToIp: uint32(vpnIp), } msg, err := m.Marshal() if err != nil { - hostinfo.logger(c.l). + hostinfo.logger(hm.l). WithError(err). Error("Failed to marshal Control message to create relay") } else { - c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - c.l.WithFields(logrus.Fields{ - "relayFrom": c.lightHouse.myVpnIp, + hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + hm.l.WithFields(logrus.Fields{ + "relayFrom": hm.lightHouse.myVpnIp, "relayTo": vpnIp, "initiatorRelayIndex": idx, "relay": *relay}). @@ -293,13 +349,13 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter)) + hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) } } // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present // The 2nd argument will be true if the hostinfo is ready to transmit traffic -func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) (*HostInfo, bool) { +func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { // Check the main hostmap and maintain a read lock if our host is not there hm.mainHostMap.RLock() if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok { @@ -316,16 +372,16 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Hos } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip -func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) *HostInfo { +func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() - if hostinfo, ok := hm.vpnIps[vpnIp]; ok { + if hh, ok := hm.vpnIps[vpnIp]; ok { // We are already trying to handshake with this vpn ip if cacheCb != nil { - cacheCb(hostinfo) + cacheCb(hh) } hm.Unlock() - return hostinfo + return hh.hostinfo } hostinfo := &HostInfo{ @@ -338,12 +394,16 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Hos }, } - hm.vpnIps[vpnIp] = hostinfo + hh := &HandshakeHostInfo{ + hostinfo: hostinfo, + startTime: time.Now(), + } + hm.vpnIps[vpnIp] = hh hm.metricInitiated.Inc(1) hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) if cacheCb != nil { - cacheCb(hostinfo) + cacheCb(hh) } // If this is a static host, we don't need to wait for the HostQueryReply @@ -416,8 +476,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket return existingIndex, ErrLocalIndexCollision } - existingIndex, found = c.indexes[hostinfo.localIndexId] - if found && existingIndex != hostinfo { + existingPendingIndex, found := c.indexes[hostinfo.localIndexId] + if found && existingPendingIndex.hostinfo != hostinfo { // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision } @@ -461,7 +521,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // allocateIndex generates a unique localIndexId for this HostInfo // and adds it to the pendingHostMap. Will error if we are unable to generate // a unique localIndexId -func (hm *HandshakeManager) allocateIndex(h *HostInfo) error { +func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { hm.mainHostMap.RLock() defer hm.mainHostMap.RUnlock() hm.Lock() @@ -477,8 +537,8 @@ func (hm *HandshakeManager) allocateIndex(h *HostInfo) error { _, inMain := hm.mainHostMap.Indexes[index] if !inMain && !inPending { - h.localIndexId = index - hm.indexes[index] = h + hh.hostinfo.localIndexId = index + hm.indexes[index] = hh return nil } } @@ -495,12 +555,12 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { delete(c.vpnIps, hostinfo.vpnIp) if len(c.vpnIps) == 0 { - c.vpnIps = map[iputil.VpnIp]*HostInfo{} + c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{} } delete(c.indexes, hostinfo.localIndexId) if len(c.vpnIps) == 0 { - c.indexes = map[uint32]*HostInfo{} + c.indexes = map[uint32]*HandshakeHostInfo{} } if c.l.Level >= logrus.DebugLevel { @@ -510,16 +570,33 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { } } -func (c *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { - c.RLock() - defer c.RUnlock() - return c.vpnIps[vpnIp] +func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { + hh := hm.queryVpnIp(vpnIp) + if hh != nil { + return hh.hostinfo + } + return nil + } -func (c *HandshakeManager) QueryIndex(index uint32) *HostInfo { - c.RLock() - defer c.RUnlock() - return c.indexes[index] +func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo { + hm.RLock() + defer hm.RUnlock() + return hm.vpnIps[vpnIp] +} + +func (hm *HandshakeManager) QueryIndex(index uint32) *HostInfo { + hh := hm.queryIndex(index) + if hh != nil { + return hh.hostinfo + } + return nil +} + +func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { + hm.RLock() + defer hm.RUnlock() + return hm.indexes[index] } func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { @@ -531,7 +608,7 @@ func (c *HandshakeManager) ForEachVpnIp(f controlEach) { defer c.RUnlock() for _, v := range c.vpnIps { - f(v) + f(v.hostinfo) } } @@ -540,7 +617,7 @@ func (c *HandshakeManager) ForEachIndex(f controlEach) { defer c.RUnlock() for _, v := range c.indexes { - f(v) + f(v.hostinfo) } } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index d318a9d..303aa50 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" @@ -21,7 +22,16 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { mainHM := NewHostMap(l, vpncidr, preferredRanges) lh := newTestLighthouse() + cs := &CertState{ + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, + } + blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) + blah.f = &Interface{handshakeManager: blah, pki: &PKI{}, l: l} + blah.f.pki.cs.Store(cs) now := time.Now() blah.NextOutboundHandshakeTimerTick(now) @@ -31,7 +41,6 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.Same(t, i, i2) i.remotes = NewRemoteList(nil) - i.HandshakeReady = true // Adding something to pending should not affect the main hostmap assert.Len(t, mainHM.Hosts, 0) diff --git a/hostmap.go b/hostmap.go index f2618c7..4358632 100644 --- a/hostmap.go +++ b/hostmap.go @@ -21,6 +21,7 @@ const defaultPromoteEvery = 1000 // Count of packets sent before we try mo const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery const MaxRemotes = 10 +const maxRecvError = 4 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip // 5 allows for an initial handshake and each host pair re-handshaking twice @@ -196,25 +197,20 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { } type HostInfo struct { - sync.RWMutex + remote *udp.Addr + remotes *RemoteList + promoteCounter atomic.Uint32 + ConnectionState *ConnectionState + remoteIndexId uint32 + localIndexId uint32 + vpnIp iputil.VpnIp + recvError atomic.Uint32 + remoteCidr *cidr.Tree4 + relayState RelayState - remote *udp.Addr - remotes *RemoteList - promoteCounter atomic.Uint32 - ConnectionState *ConnectionState - handshakeStart time.Time //todo: this an entry in the handshake manager - HandshakeReady bool //todo: being in the manager means you are ready - HandshakeCounter int //todo: another handshake manager entry - HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time - HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready - HandshakePacket map[uint8][]byte - packetStore []*cachedPacket //todo: this is other handshake manager entry - remoteIndexId uint32 - localIndexId uint32 - vpnIp iputil.VpnIp - recvError int - remoteCidr *cidr.Tree4 - relayState RelayState + // HandshakePacket records the packets used to create this hostinfo + // We need these to avoid replayed handshake packets creating new hostinfos which causes churn + HandshakePacket map[uint8][]byte // nextLHQuery is the earliest we can ask the lighthouse for new information. // This is used to limit lighthouse re-queries in chatty clients @@ -412,7 +408,6 @@ func (hm *HostMap) QueryIndex(index uint32) *HostInfo { } func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo { - //TODO: we probably just want to return bool instead of error, or at least a static error hm.RLock() if h, ok := hm.Relays[index]; ok { hm.RUnlock() @@ -535,10 +530,7 @@ func (hm *HostMap) ForEachIndex(f controlEach) { func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { c := i.promoteCounter.Add(1) if c%ifce.tryPromoteEvery.Load() == 0 { - // The lock here is currently protecting i.remote access - i.RLock() remote := i.remote - i.RUnlock() // return early if we are already on a preferred remote if remote != nil { @@ -573,58 +565,6 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } } -func (i *HostInfo) unlockedCachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { - //TODO: return the error so we can log with more context - if len(i.packetStore) < 100 { - tempPacket := make([]byte, len(packet)) - copy(tempPacket, packet) - //l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket) - i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket}) - if l.Level >= logrus.DebugLevel { - i.logger(l). - WithField("length", len(i.packetStore)). - WithField("stored", true). - Debugf("Packet store") - } - - } else if l.Level >= logrus.DebugLevel { - m.dropped.Inc(1) - i.logger(l). - WithField("length", len(i.packetStore)). - WithField("stored", false). - Debugf("Packet store") - } -} - -// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets -func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) { - //TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because: - //TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send - //TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical - - i.HandshakeComplete = true - //TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen. - // Clamping it to 2 gets us out of the woods for now - i.ConnectionState.messageCounter.Store(2) - - if l.Level >= logrus.DebugLevel { - i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore)) - } - - if len(i.packetStore) > 0 { - nb := make([]byte, 12, 12) - out := make([]byte, mtu) - for _, cp := range i.packetStore { - cp.callback(cp.messageType, cp.messageSubType, i, cp.packet, nb, out) - } - m.sent.Inc(int64(len(i.packetStore))) - } - - i.remotes.ResetBlockedRemotes() - i.packetStore = make([]*cachedPacket, 0) - i.ConnectionState.ready = true -} - func (i *HostInfo) GetCert() *cert.NebulaCertificate { if i.ConnectionState != nil { return i.ConnectionState.peerCert @@ -681,9 +621,8 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { } func (i *HostInfo) RecvErrorExceeded() bool { - if i.recvError < 3 { - i.recvError += 1 - return false + if i.recvError.Add(1) >= maxRecvError { + return true } return true } diff --git a/inside.go b/inside.go index 2219d2b..9250b5e 100644 --- a/inside.go +++ b/inside.go @@ -44,8 +44,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(h *HostInfo) { - h.unlockedCachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) if hostinfo == nil { @@ -108,7 +108,7 @@ func (f *Interface) Handshake(vpnIp iputil.VpnIp) { // getOrHandshake returns nil if the vpnIp is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(info *HostInfo)) (*HostInfo, bool) { +func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { vpnIp = f.inside.RouteFor(vpnIp) if vpnIp == 0 { @@ -143,8 +143,8 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { - hostInfo, ready := f.getOrHandshake(vpnIp, func(h *HostInfo) { - h.unlockedCachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) + hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) if hostInfo == nil { diff --git a/outside.go b/outside.go index a9dcdc8..4139830 100644 --- a/outside.go +++ b/outside.go @@ -198,7 +198,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.Handshake: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - HandleIncomingHandshake(f, addr, via, packet, h, hostinfo) + f.handshakeManager.HandleIncoming(addr, via, packet, h) return case header.RecvError: @@ -455,9 +455,6 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { return } - hostinfo.Lock() - defer hostinfo.Unlock() - if !hostinfo.RecvErrorExceeded() { return }