update EncReader and EncWriter interface function args to have concrete types (#844)

* Update LightHouseHandlerFunc to remove EncWriter param.
* Move EncWriter to interface
* EncReader, too
This commit is contained in:
brad-defined 2023-04-07 14:28:37 -04:00 committed by GitHub
parent 3cb4e0ef57
commit 9b03053191
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 69 additions and 54 deletions

View File

@ -5,7 +5,7 @@ import (
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H, hostinfo *HostInfo) { func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) {
// First remote allow list check before we know the vpnIp // First remote allow list check before we know the vpnIp
if addr != nil { if addr != nil {
if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {

View File

@ -68,7 +68,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
hostinfo.handshakeStart = time.Now() hostinfo.handshakeStart = time.Now()
} }
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H) { func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0) ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed // Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1) ci.window.Update(f.l, 1)
@ -240,14 +240,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
} }
return return
} else { } else {
via2 := via.(*ViaSender) if via == nil {
if via2 == nil {
f.l.Error("Handshake send failed: both addr and via are nil.") f.l.Error("Handshake send failed: both addr and via are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via2.relayHI.vpnIp). f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent") Info("Handshake message sent")
return return
@ -315,14 +314,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
Info("Handshake message sent") Info("Handshake message sent")
} }
} else { } else {
via2 := via.(*ViaSender) if via == nil {
if via2 == nil {
f.l.Error("Handshake send failed: both addr and via are nil.") f.l.Error("Handshake send failed: both addr and via are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", vpnIp).WithField("relay", via2.relayHI.vpnIp). f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -338,7 +336,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
return return
} }
func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *HostInfo, packet []byte, h *header.H) bool { func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool {
if hostinfo == nil { if hostinfo == nil {
// Nothing here to tear down, got a bogus stage 2 packet // Nothing here to tear down, got a bogus stage 2 packet
return true return true
@ -482,8 +480,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *
if addr != nil { if addr != nil {
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
} else { } else {
via2 := via.(*ViaSender) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
} }
// Build up the radix for the firewall if we have subnets in the cert // Build up the radix for the firewall if we have subnets in the cert

View File

@ -73,7 +73,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
} }
} }
func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) { func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
clockSource := time.NewTicker(c.config.tryInterval) clockSource := time.NewTicker(c.config.tryInterval)
defer clockSource.Stop() defer clockSource.Stop()
@ -89,7 +89,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
} }
} }
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) { func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
c.OutboundHandshakeTimer.Advance(now) c.OutboundHandshakeTimer.Advance(now)
for { for {
vpnIp, has := c.OutboundHandshakeTimer.Purge() vpnIp, has := c.OutboundHandshakeTimer.Purge()
@ -100,7 +100,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E
} }
} }
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) { func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp) hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil { if err != nil {
return return

View File

@ -84,7 +84,7 @@ func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
return return
} }
func (mw *mockEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) { func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
return return
} }

View File

@ -248,16 +248,13 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
// nb is a buffer used to store the nonce value, re-used for performance reasons. // nb is a buffer used to store the nonce value, re-used for performance reasons.
// out is a buffer used to store the result of the Encrypt operation // out is a buffer used to store the result of the Encrypt operation
// q indicates which writer to use to send the packet. // q indicates which writer to use to send the packet.
func (f *Interface) SendVia(viaIfc interface{}, func (f *Interface) SendVia(via *HostInfo,
relayIfc interface{}, relay *Relay,
ad, ad,
nb, nb,
out []byte, out []byte,
nocopy bool, nocopy bool,
) { ) {
via := viaIfc.(*HostInfo)
relay := relayIfc.(*Relay)
if noiseutil.EncryptLockNeeded { if noiseutil.EncryptLockNeeded {
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
via.ConnectionState.writeLock.Lock() via.ConnectionState.writeLock.Lock()

View File

@ -16,6 +16,7 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
@ -89,6 +90,18 @@ type Interface struct {
l *logrus.Logger l *logrus.Logger
} }
type EncWriter interface {
SendVia(via *HostInfo,
relay *Relay,
ad,
nb,
out []byte,
nocopy bool,
)
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
Handshake(vpnIp iputil.VpnIp)
}
type sendRecvErrorConfig uint8 type sendRecvErrorConfig uint8
const ( const (
@ -238,7 +251,7 @@ func (f *Interface) listenOut(i int) {
lhh := f.lightHouse.NewRequestHandler() lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i) li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {

View File

@ -65,7 +65,7 @@ type LightHouse struct {
interval atomic.Int64 interval atomic.Int64
updateCancel context.CancelFunc updateCancel context.CancelFunc
updateParentCtx context.Context updateParentCtx context.Context
updateUdp udp.EncWriter updateUdp EncWriter
nebulaPort uint32 // 32 bits because protobuf does not have a uint16 nebulaPort uint32 // 32 bits because protobuf does not have a uint16
advertiseAddrs atomic.Pointer[[]netIpAndPort] advertiseAddrs atomic.Pointer[[]netIpAndPort]
@ -382,7 +382,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
return nil return nil
} }
func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList { func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
if !lh.IsLighthouseIP(ip) { if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip, f) lh.QueryServer(ip, f)
} }
@ -396,7 +396,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
} }
// This is asynchronous so no reply should be expected // This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) { func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) {
if lh.amLighthouse { if lh.amLighthouse {
return return
} }
@ -629,7 +629,7 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
} }
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) { func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
lh.updateParentCtx = ctx lh.updateParentCtx = ctx
lh.updateUdp = f lh.updateUdp = f
@ -655,7 +655,7 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
} }
} }
func (lh *LightHouse) SendUpdate(f udp.EncWriter) { func (lh *LightHouse) SendUpdate(f EncWriter) {
var v4 []*Ip4AndPort var v4 []*Ip4AndPort
var v6 []*Ip6AndPort var v6 []*Ip6AndPort
@ -760,7 +760,13 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
return lhh.meta return lhh.meta
} }
func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) { func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) {
lhh.HandleRequest(rAddr, vpnIp, p, f)
}
}
func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
n := lhh.resetMeta() n := lhh.resetMeta()
err := n.Unmarshal(p) err := n.Unmarshal(p)
if err != nil { if err != nil {
@ -795,7 +801,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
} }
} }
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) { func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) {
// Exit if we don't answer queries // Exit if we don't answer queries
if !lhh.lh.amLighthouse { if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel { if lhh.l.Level >= logrus.DebugLevel {
@ -928,7 +934,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
am.Unlock() am.Unlock()
} }
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) { func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
if !lhh.lh.IsLighthouseIP(vpnIp) { if !lhh.lh.IsLighthouseIP(vpnIp) {
return return
} }

View File

@ -372,7 +372,7 @@ type testEncWriter struct {
metaFilter *NebulaMeta_MessageType metaFilter *NebulaMeta_MessageType
} }
func (tw *testEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) { func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
} }
func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
} }

View File

@ -21,7 +21,23 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { func readOutsidePackets(f *Interface) udp.EncReader {
return func(
addr *udp.Addr,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh udp.LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
) {
f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
}
}
func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// TODO: best if we return this and let caller log // TODO: best if we return this and let caller log
@ -149,7 +165,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by
return return
} }
lhf(addr, hostinfo.vpnIp, d, f) lhf(addr, hostinfo.vpnIp, d)
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic

View File

@ -9,7 +9,6 @@ const MTU = 9001
type EncReader func( type EncReader func(
addr *Addr, addr *Addr,
via interface{},
out []byte, out []byte,
packet []byte, packet []byte,
header *header.H, header *header.H,

View File

@ -1,22 +1,9 @@
package udp package udp
import ( import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
) )
type EncWriter interface {
SendVia(via interface{},
relay interface{},
ad,
nb,
out []byte,
nocopy bool,
)
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
Handshake(vpnIp iputil.VpnIp)
}
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte)

View File

@ -86,6 +86,6 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
udpAddr.IP = rua.IP udpAddr.IP = rua.IP
udpAddr.Port = uint16(rua.Port) udpAddr.Port = uint16(rua.Port)
r(udpAddr, nil, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
} }
} }

View File

@ -145,7 +145,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
udpAddr.IP = names[i][8:24] udpAddr.IP = names[i][8:24]
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
r(udpAddr, nil, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
} }
} }
} }

View File

@ -122,7 +122,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
} }
ua.Port = p.FromPort ua.Port = p.FromPort
copy(ua.IP, p.FromIp.To16()) copy(ua.IP, p.FromIp.To16())
r(ua, nil, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
} }
} }