From 4c066d8c3257cb800f0aad09a1f53a37ebfa1686 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Thu, 6 Jun 2024 13:03:07 -0400 Subject: [PATCH] initialize messageCounter to 2 instead of verifying later (#1156) Clean up the messageCounter checks added in #1154. Instead of checking that messageCounter is still at 2, just initialize it to 2 and only increment for non-handshake messages. Handshake packets will always be packets 1 and 2. --- connection_state.go | 2 ++ handshake_ix.go | 11 ----------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/connection_state.go b/connection_state.go index 8ef8b3a..1dd3c8c 100644 --- a/connection_state.go +++ b/connection_state.go @@ -72,6 +72,8 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i window: b, myCert: certState.Certificate, } + // always start the counter from 2, as packet 1 and packet 2 are handshake packets. + ci.messageCounter.Add(2) return ci } diff --git a/handshake_ix.go b/handshake_ix.go index b86ecab..d0bee86 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -1,7 +1,6 @@ package nebula import ( - "fmt" "time" "github.com/flynn/noise" @@ -47,7 +46,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { } h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) - ci.messageCounter.Add(1) msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { @@ -322,10 +320,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by } f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - prev := hostinfo.ConnectionState.messageCounter.Swap(2) - if prev > 2 { - panic(fmt.Errorf("invalid state: messageCounter > 2 before handshake complete: %v", prev)) - } hostinfo.remotes.ResetBlockedRemotes() @@ -468,11 +462,6 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha // Build up the radix for the firewall if we have subnets in the cert hostinfo.CreateRemoteCIDR(remoteCert) - prev := hostinfo.ConnectionState.messageCounter.Swap(2) - if prev > 2 { - panic(fmt.Errorf("invalid state: messageCounter > 2 before handshake complete: %v", prev)) - } - // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)