update EncReader and EncWriter interface function args to have concrete types (#844)

* Update LightHouseHandlerFunc to remove EncWriter param.
* Move EncWriter to interface
* EncReader, too
This commit is contained in:
brad-defined 2023-04-07 14:28:37 -04:00 committed by GitHub
parent 3cb4e0ef57
commit 9b03053191
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 69 additions and 54 deletions

View File

@ -5,7 +5,7 @@ import (
"github.com/slackhq/nebula/udp"
)
func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H, hostinfo *HostInfo) {
func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) {
// First remote allow list check before we know the vpnIp
if addr != nil {
if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {

View File

@ -68,7 +68,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
hostinfo.handshakeStart = time.Now()
}
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H) {
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
@ -240,14 +240,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
}
return
} else {
via2 := via.(*ViaSender)
if via2 == nil {
if via == nil {
f.l.Error("Handshake send failed: both addr and via are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via2.relayHI.vpnIp).
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
return
@ -315,14 +314,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
Info("Handshake message sent")
}
} else {
via2 := via.(*ViaSender)
if via2 == nil {
if via == nil {
f.l.Error("Handshake send failed: both addr and via are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", vpnIp).WithField("relay", via2.relayHI.vpnIp).
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -338,7 +336,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
return
}
func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *HostInfo, packet []byte, h *header.H) bool {
func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool {
if hostinfo == nil {
// Nothing here to tear down, got a bogus stage 2 packet
return true
@ -482,8 +480,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *
if addr != nil {
hostinfo.SetRemote(addr)
} else {
via2 := via.(*ViaSender)
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
}
// Build up the radix for the firewall if we have subnets in the cert

View File

@ -73,7 +73,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
}
}
func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
clockSource := time.NewTicker(c.config.tryInterval)
defer clockSource.Stop()
@ -89,7 +89,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
}
}
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
c.OutboundHandshakeTimer.Advance(now)
for {
vpnIp, has := c.OutboundHandshakeTimer.Purge()
@ -100,7 +100,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E
}
}
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil {
return

View File

@ -84,7 +84,7 @@ func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
return
}
func (mw *mockEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) {
func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
return
}

View File

@ -248,16 +248,13 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
// nb is a buffer used to store the nonce value, re-used for performance reasons.
// out is a buffer used to store the result of the Encrypt operation
// q indicates which writer to use to send the packet.
func (f *Interface) SendVia(viaIfc interface{},
relayIfc interface{},
func (f *Interface) SendVia(via *HostInfo,
relay *Relay,
ad,
nb,
out []byte,
nocopy bool,
) {
via := viaIfc.(*HostInfo)
relay := relayIfc.(*Relay)
if noiseutil.EncryptLockNeeded {
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
via.ConnectionState.writeLock.Lock()

View File

@ -16,6 +16,7 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
@ -89,6 +90,18 @@ type Interface struct {
l *logrus.Logger
}
type EncWriter interface {
SendVia(via *HostInfo,
relay *Relay,
ad,
nb,
out []byte,
nocopy bool,
)
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
Handshake(vpnIp iputil.VpnIp)
}
type sendRecvErrorConfig uint8
const (
@ -238,7 +251,7 @@ func (f *Interface) listenOut(i int) {
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i)
li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {

View File

@ -65,7 +65,7 @@ type LightHouse struct {
interval atomic.Int64
updateCancel context.CancelFunc
updateParentCtx context.Context
updateUdp udp.EncWriter
updateUdp EncWriter
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
advertiseAddrs atomic.Pointer[[]netIpAndPort]
@ -382,7 +382,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
return nil
}
func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip, f)
}
@ -396,7 +396,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
}
// This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) {
func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) {
if lh.amLighthouse {
return
}
@ -629,7 +629,7 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
}
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
lh.updateParentCtx = ctx
lh.updateUdp = f
@ -655,7 +655,7 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
}
}
func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
func (lh *LightHouse) SendUpdate(f EncWriter) {
var v4 []*Ip4AndPort
var v6 []*Ip6AndPort
@ -760,7 +760,13 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
return lhh.meta
}
func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) {
func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) {
lhh.HandleRequest(rAddr, vpnIp, p, f)
}
}
func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
@ -795,7 +801,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
}
}
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) {
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
@ -928,7 +934,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
am.Unlock()
}
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) {
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
if !lhh.lh.IsLighthouseIP(vpnIp) {
return
}

View File

@ -372,7 +372,7 @@ type testEncWriter struct {
metaFilter *NebulaMeta_MessageType
}
func (tw *testEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) {
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
}
func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
}

View File

@ -21,7 +21,23 @@ const (
minFwPacketLen = 4
)
func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
func readOutsidePackets(f *Interface) udp.EncReader {
return func(
addr *udp.Addr,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh udp.LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
) {
f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
}
}
func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// TODO: best if we return this and let caller log
@ -149,7 +165,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by
return
}
lhf(addr, hostinfo.vpnIp, d, f)
lhf(addr, hostinfo.vpnIp, d)
// Fallthrough to the bottom to record incoming traffic

View File

@ -9,7 +9,6 @@ const MTU = 9001
type EncReader func(
addr *Addr,
via interface{},
out []byte,
packet []byte,
header *header.H,

View File

@ -1,22 +1,9 @@
package udp
import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
)
type EncWriter interface {
SendVia(via interface{},
relay interface{},
ad,
nb,
out []byte,
nocopy bool,
)
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
Handshake(vpnIp iputil.VpnIp)
}
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter)
type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte)

View File

@ -86,6 +86,6 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
udpAddr.IP = rua.IP
udpAddr.Port = uint16(rua.Port)
r(udpAddr, nil, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}

View File

@ -145,7 +145,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
for i := 0; i < n; i++ {
udpAddr.IP = names[i][8:24]
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
r(udpAddr, nil, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}
}

View File

@ -122,7 +122,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
}
ua.Port = p.FromPort
copy(ua.IP, p.FromIp.To16())
r(ua, nil, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}