From 9b030531910847a1fa448263a1cb1a41330299f9 Mon Sep 17 00:00:00 2001 From: brad-defined <77982333+brad-defined@users.noreply.github.com> Date: Fri, 7 Apr 2023 14:28:37 -0400 Subject: [PATCH] update EncReader and EncWriter interface function args to have concrete types (#844) * Update LightHouseHandlerFunc to remove EncWriter param. * Move EncWriter to interface * EncReader, too --- handshake.go | 2 +- handshake_ix.go | 25 +++++++++++-------------- handshake_manager.go | 6 +++--- handshake_manager_test.go | 2 +- inside.go | 7 ++----- interface.go | 15 ++++++++++++++- lighthouse.go | 22 ++++++++++++++-------- lighthouse_test.go | 2 +- outside.go | 20 ++++++++++++++++++-- udp/conn.go | 1 - udp/temp.go | 15 +-------------- udp/udp_generic.go | 2 +- udp/udp_linux.go | 2 +- udp/udp_tester.go | 2 +- 14 files changed, 69 insertions(+), 54 deletions(-) diff --git a/handshake.go b/handshake.go index 1cad0db..1f2f03a 100644 --- a/handshake.go +++ b/handshake.go @@ -5,7 +5,7 @@ import ( "github.com/slackhq/nebula/udp" ) -func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H, hostinfo *HostInfo) { +func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) { // First remote allow list check before we know the vpnIp if addr != nil { if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { diff --git a/handshake_ix.go b/handshake_ix.go index a51fb31..b6b5658 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -68,7 +68,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { hostinfo.handshakeStart = time.Now() } -func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H) { +func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) @@ -240,14 +240,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b } return } else { - via2 := via.(*ViaSender) - if via2 == nil { + if via == nil { f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) - f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via2.relayHI.vpnIp). + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) + f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") return @@ -315,14 +314,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b Info("Handshake message sent") } } else { - via2 := via.(*ViaSender) - if via2 == nil { + if via == nil { f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) - f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", vpnIp).WithField("relay", via2.relayHI.vpnIp). + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) + f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -338,7 +336,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b return } -func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *HostInfo, packet []byte, h *header.H) bool { +func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool { if hostinfo == nil { // Nothing here to tear down, got a bogus stage 2 packet return true @@ -482,8 +480,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo * if addr != nil { hostinfo.SetRemote(addr) } else { - via2 := via.(*ViaSender) - hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) } // Build up the radix for the firewall if we have subnets in the cert diff --git a/handshake_manager.go b/handshake_manager.go index c8a01ca..ce2811b 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -73,7 +73,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ } } -func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) { +func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) { clockSource := time.NewTicker(c.config.tryInterval) defer clockSource.Stop() @@ -89,7 +89,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) { } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) { +func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) { c.OutboundHandshakeTimer.Advance(now) for { vpnIp, has := c.OutboundHandshakeTimer.Purge() @@ -100,7 +100,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E } } -func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) { +func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) { hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp) if err != nil { return diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 84b8ef6..3be8a1b 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -84,7 +84,7 @@ func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess return } -func (mw *mockEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) { +func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { return } diff --git a/inside.go b/inside.go index 457fcac..9c40251 100644 --- a/inside.go +++ b/inside.go @@ -248,16 +248,13 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C // nb is a buffer used to store the nonce value, re-used for performance reasons. // out is a buffer used to store the result of the Encrypt operation // q indicates which writer to use to send the packet. -func (f *Interface) SendVia(viaIfc interface{}, - relayIfc interface{}, +func (f *Interface) SendVia(via *HostInfo, + relay *Relay, ad, nb, out []byte, nocopy bool, ) { - via := viaIfc.(*HostInfo) - relay := relayIfc.(*Relay) - if noiseutil.EncryptLockNeeded { // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check via.ConnectionState.writeLock.Lock() diff --git a/interface.go b/interface.go index af83abc..e87f9f9 100644 --- a/interface.go +++ b/interface.go @@ -16,6 +16,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" @@ -89,6 +90,18 @@ type Interface struct { l *logrus.Logger } +type EncWriter interface { + SendVia(via *HostInfo, + relay *Relay, + ad, + nb, + out []byte, + nocopy bool, + ) + SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) + Handshake(vpnIp iputil.VpnIp) +} + type sendRecvErrorConfig uint8 const ( @@ -238,7 +251,7 @@ func (f *Interface) listenOut(i int) { lhh := f.lightHouse.NewRequestHandler() conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i) + li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { diff --git a/lighthouse.go b/lighthouse.go index 5b34a3e..d6b6a5f 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -65,7 +65,7 @@ type LightHouse struct { interval atomic.Int64 updateCancel context.CancelFunc updateParentCtx context.Context - updateUdp udp.EncWriter + updateUdp EncWriter nebulaPort uint32 // 32 bits because protobuf does not have a uint16 advertiseAddrs atomic.Pointer[[]netIpAndPort] @@ -382,7 +382,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return nil } -func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList { +func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList { if !lh.IsLighthouseIP(ip) { lh.QueryServer(ip, f) } @@ -396,7 +396,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList { } // This is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) { +func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) { if lh.amLighthouse { return } @@ -629,7 +629,7 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) } -func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) { +func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) { lh.updateParentCtx = ctx lh.updateUdp = f @@ -655,7 +655,7 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) { } } -func (lh *LightHouse) SendUpdate(f udp.EncWriter) { +func (lh *LightHouse) SendUpdate(f EncWriter) { var v4 []*Ip4AndPort var v6 []*Ip6AndPort @@ -760,7 +760,13 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { return lhh.meta } -func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) { +func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { + return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) { + lhh.HandleRequest(rAddr, vpnIp, p, f) + } +} + +func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { @@ -795,7 +801,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -928,7 +934,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp am.Unlock() } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } diff --git a/lighthouse_test.go b/lighthouse_test.go index e5a1692..1824463 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -372,7 +372,7 @@ type testEncWriter struct { metaFilter *NebulaMeta_MessageType } -func (tw *testEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) { +func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { } func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { } diff --git a/outside.go b/outside.go index fd6f0a3..8361ce3 100644 --- a/outside.go +++ b/outside.go @@ -21,7 +21,23 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func readOutsidePackets(f *Interface) udp.EncReader { + return func( + addr *udp.Addr, + out []byte, + packet []byte, + header *header.H, + fwPacket *firewall.Packet, + lhh udp.LightHouseHandlerFunc, + nb []byte, + q int, + localCache firewall.ConntrackCache, + ) { + f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache) + } +} + +func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // TODO: best if we return this and let caller log @@ -149,7 +165,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by return } - lhf(addr, hostinfo.vpnIp, d, f) + lhf(addr, hostinfo.vpnIp, d) // Fallthrough to the bottom to record incoming traffic diff --git a/udp/conn.go b/udp/conn.go index fa52fe5..f967a9a 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -9,7 +9,6 @@ const MTU = 9001 type EncReader func( addr *Addr, - via interface{}, out []byte, packet []byte, header *header.H, diff --git a/udp/temp.go b/udp/temp.go index 5cc8c1c..2efe31d 100644 --- a/udp/temp.go +++ b/udp/temp.go @@ -1,22 +1,9 @@ package udp import ( - "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" ) -type EncWriter interface { - SendVia(via interface{}, - relay interface{}, - ad, - nb, - out []byte, - nocopy bool, - ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) - Handshake(vpnIp iputil.VpnIp) -} - //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare -type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) +type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte) diff --git a/udp/udp_generic.go b/udp/udp_generic.go index f03174d..ff254eb 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -86,6 +86,6 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall udpAddr.IP = rua.IP udpAddr.Port = uint16(rua.Port) - r(udpAddr, nil, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 77102ab..26bbe36 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -145,7 +145,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall for i := 0; i < n; i++ { udpAddr.IP = names[i][8:24] udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - r(udpAddr, nil, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 3b33f0d..8b5e531 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -122,7 +122,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall } ua.Port = p.FromPort copy(ua.IP, p.FromIp.To16()) - r(ua, nil, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) } }