mirror of https://github.com/slackhq/nebula.git
update EncReader and EncWriter interface function args to have concrete types (#844)
* Update LightHouseHandlerFunc to remove EncWriter param. * Move EncWriter to interface * EncReader, too
This commit is contained in:
parent
3cb4e0ef57
commit
9b03053191
|
@ -5,7 +5,7 @@ import (
|
|||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H, hostinfo *HostInfo) {
|
||||
func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) {
|
||||
// First remote allow list check before we know the vpnIp
|
||||
if addr != nil {
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
|
||||
|
|
|
@ -68,7 +68,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
|
|||
hostinfo.handshakeStart = time.Now()
|
||||
}
|
||||
|
||||
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H) {
|
||||
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
|
||||
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 1)
|
||||
|
@ -240,14 +240,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
|
|||
}
|
||||
return
|
||||
} else {
|
||||
via2 := via.(*ViaSender)
|
||||
if via2 == nil {
|
||||
if via == nil {
|
||||
f.l.Error("Handshake send failed: both addr and via are nil.")
|
||||
return
|
||||
}
|
||||
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
|
||||
f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via2.relayHI.vpnIp).
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
|
||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
Info("Handshake message sent")
|
||||
return
|
||||
|
@ -315,14 +314,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
|
|||
Info("Handshake message sent")
|
||||
}
|
||||
} else {
|
||||
via2 := via.(*ViaSender)
|
||||
if via2 == nil {
|
||||
if via == nil {
|
||||
f.l.Error("Handshake send failed: both addr and via are nil.")
|
||||
return
|
||||
}
|
||||
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
|
||||
f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("relay", via2.relayHI.vpnIp).
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
|
||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
|
@ -338,7 +336,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
|
|||
return
|
||||
}
|
||||
|
||||
func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *HostInfo, packet []byte, h *header.H) bool {
|
||||
func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool {
|
||||
if hostinfo == nil {
|
||||
// Nothing here to tear down, got a bogus stage 2 packet
|
||||
return true
|
||||
|
@ -482,8 +480,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *
|
|||
if addr != nil {
|
||||
hostinfo.SetRemote(addr)
|
||||
} else {
|
||||
via2 := via.(*ViaSender)
|
||||
hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
|
||||
}
|
||||
|
||||
// Build up the radix for the firewall if we have subnets in the cert
|
||||
|
|
|
@ -73,7 +73,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
|
|||
}
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
|
||||
func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
|
||||
clockSource := time.NewTicker(c.config.tryInterval)
|
||||
defer clockSource.Stop()
|
||||
|
||||
|
@ -89,7 +89,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
|
||||
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
|
||||
c.OutboundHandshakeTimer.Advance(now)
|
||||
for {
|
||||
vpnIp, has := c.OutboundHandshakeTimer.Purge()
|
||||
|
@ -100,7 +100,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E
|
|||
}
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
|
||||
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
|
||||
hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
@ -84,7 +84,7 @@ func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
|
|||
return
|
||||
}
|
||||
|
||||
func (mw *mockEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) {
|
||||
func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -248,16 +248,13 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
|
|||
// nb is a buffer used to store the nonce value, re-used for performance reasons.
|
||||
// out is a buffer used to store the result of the Encrypt operation
|
||||
// q indicates which writer to use to send the packet.
|
||||
func (f *Interface) SendVia(viaIfc interface{},
|
||||
relayIfc interface{},
|
||||
func (f *Interface) SendVia(via *HostInfo,
|
||||
relay *Relay,
|
||||
ad,
|
||||
nb,
|
||||
out []byte,
|
||||
nocopy bool,
|
||||
) {
|
||||
via := viaIfc.(*HostInfo)
|
||||
relay := relayIfc.(*Relay)
|
||||
|
||||
if noiseutil.EncryptLockNeeded {
|
||||
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
|
||||
via.ConnectionState.writeLock.Lock()
|
||||
|
|
15
interface.go
15
interface.go
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
|
@ -89,6 +90,18 @@ type Interface struct {
|
|||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type EncWriter interface {
|
||||
SendVia(via *HostInfo,
|
||||
relay *Relay,
|
||||
ad,
|
||||
nb,
|
||||
out []byte,
|
||||
nocopy bool,
|
||||
)
|
||||
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
|
||||
Handshake(vpnIp iputil.VpnIp)
|
||||
}
|
||||
|
||||
type sendRecvErrorConfig uint8
|
||||
|
||||
const (
|
||||
|
@ -238,7 +251,7 @@ func (f *Interface) listenOut(i int) {
|
|||
|
||||
lhh := f.lightHouse.NewRequestHandler()
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i)
|
||||
li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
|
||||
}
|
||||
|
||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
|
|
|
@ -65,7 +65,7 @@ type LightHouse struct {
|
|||
interval atomic.Int64
|
||||
updateCancel context.CancelFunc
|
||||
updateParentCtx context.Context
|
||||
updateUdp udp.EncWriter
|
||||
updateUdp EncWriter
|
||||
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
|
||||
|
||||
advertiseAddrs atomic.Pointer[[]netIpAndPort]
|
||||
|
@ -382,7 +382,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
|
|||
return nil
|
||||
}
|
||||
|
||||
func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
|
||||
func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
|
||||
if !lh.IsLighthouseIP(ip) {
|
||||
lh.QueryServer(ip, f)
|
||||
}
|
||||
|
@ -396,7 +396,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
|
|||
}
|
||||
|
||||
// This is asynchronous so no reply should be expected
|
||||
func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) {
|
||||
func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) {
|
||||
if lh.amLighthouse {
|
||||
return
|
||||
}
|
||||
|
@ -629,7 +629,7 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
|
|||
return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
|
||||
}
|
||||
|
||||
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
|
||||
func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
|
||||
lh.updateParentCtx = ctx
|
||||
lh.updateUdp = f
|
||||
|
||||
|
@ -655,7 +655,7 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
|
|||
}
|
||||
}
|
||||
|
||||
func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
|
||||
func (lh *LightHouse) SendUpdate(f EncWriter) {
|
||||
var v4 []*Ip4AndPort
|
||||
var v6 []*Ip6AndPort
|
||||
|
||||
|
@ -760,7 +760,13 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
|
|||
return lhh.meta
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) {
|
||||
func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
|
||||
return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) {
|
||||
lhh.HandleRequest(rAddr, vpnIp, p, f)
|
||||
}
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
|
||||
n := lhh.resetMeta()
|
||||
err := n.Unmarshal(p)
|
||||
if err != nil {
|
||||
|
@ -795,7 +801,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
|
|||
}
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) {
|
||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) {
|
||||
// Exit if we don't answer queries
|
||||
if !lhh.lh.amLighthouse {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
|
@ -928,7 +934,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
|
|||
am.Unlock()
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) {
|
||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
|
||||
if !lhh.lh.IsLighthouseIP(vpnIp) {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -372,7 +372,7 @@ type testEncWriter struct {
|
|||
metaFilter *NebulaMeta_MessageType
|
||||
}
|
||||
|
||||
func (tw *testEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) {
|
||||
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
|
||||
}
|
||||
func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
|
||||
}
|
||||
|
|
20
outside.go
20
outside.go
|
@ -21,7 +21,23 @@ const (
|
|||
minFwPacketLen = 4
|
||||
)
|
||||
|
||||
func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||
func readOutsidePackets(f *Interface) udp.EncReader {
|
||||
return func(
|
||||
addr *udp.Addr,
|
||||
out []byte,
|
||||
packet []byte,
|
||||
header *header.H,
|
||||
fwPacket *firewall.Packet,
|
||||
lhh udp.LightHouseHandlerFunc,
|
||||
nb []byte,
|
||||
q int,
|
||||
localCache firewall.ConntrackCache,
|
||||
) {
|
||||
f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||
err := h.Parse(packet)
|
||||
if err != nil {
|
||||
// TODO: best if we return this and let caller log
|
||||
|
@ -149,7 +165,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by
|
|||
return
|
||||
}
|
||||
|
||||
lhf(addr, hostinfo.vpnIp, d, f)
|
||||
lhf(addr, hostinfo.vpnIp, d)
|
||||
|
||||
// Fallthrough to the bottom to record incoming traffic
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@ const MTU = 9001
|
|||
|
||||
type EncReader func(
|
||||
addr *Addr,
|
||||
via interface{},
|
||||
out []byte,
|
||||
packet []byte,
|
||||
header *header.H,
|
||||
|
|
15
udp/temp.go
15
udp/temp.go
|
@ -1,22 +1,9 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
type EncWriter interface {
|
||||
SendVia(via interface{},
|
||||
relay interface{},
|
||||
ad,
|
||||
nb,
|
||||
out []byte,
|
||||
nocopy bool,
|
||||
)
|
||||
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
|
||||
Handshake(vpnIp iputil.VpnIp)
|
||||
}
|
||||
|
||||
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
|
||||
|
||||
type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter)
|
||||
type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte)
|
||||
|
|
|
@ -86,6 +86,6 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
|
|||
|
||||
udpAddr.IP = rua.IP
|
||||
udpAddr.Port = uint16(rua.Port)
|
||||
r(udpAddr, nil, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -145,7 +145,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
|
|||
for i := 0; i < n; i++ {
|
||||
udpAddr.IP = names[i][8:24]
|
||||
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
|
||||
r(udpAddr, nil, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||
r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -122,7 +122,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
|
|||
}
|
||||
ua.Port = p.FromPort
|
||||
copy(ua.IP, p.FromIp.To16())
|
||||
r(ua, nil, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||
r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue