add calculated_remotes (#759)

* add calculated_remotes

This setting allows us to "guess" what the remote might be for a host
while we wait for the lighthouse response. For networks that hard
designed with in mind, it can help speed up handshake performance, as well as
improve resiliency in the case that all lighthouses are down.

Example:

    lighthouse:
      # ...

      calculated_remotes:
        # For any Nebula IPs in 10.0.10.0/24, this will apply the mask and add
        # the calculated IP as an initial remote (while we wait for the response
        # from the lighthouse). Both CIDRs must have the same mask size.
        # For example, Nebula IP 10.0.10.123 will have a calculated remote of
        # 192.168.1.123

        10.0.10.0/24:
          - mask: 192.168.1.0/24
            port: 4242

* figure out what is up with this test

* add test

* better logic for sending handshakes

Keep track of the last light of hosts we sent handshakes to. Only log
handshake sent messages if the list has changed.

Remove the test Test_NewHandshakeManagerTrigger because it is faulty and
makes no sense. It relys on the fact that no handshake packets actually
get sent, but with these changes we would send packets now (which it
should!)

* use atomic.Pointer

* cleanup to make it clearer

* fix typo in example
This commit is contained in:
Wade Simmons 2023-03-13 15:09:08 -04:00 committed by GitHub
parent 6e0ae4f9a3
commit e1af37e46d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 300 additions and 69 deletions

143
calculated_remote.go Normal file
View File

@ -0,0 +1,143 @@
package nebula
import (
"fmt"
"math"
"net"
"strconv"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
)
// This allows us to "guess" what the remote might be for a host while we wait
// for the lighthouse response. See "lighthouse.calculated_remotes" in the
// example config file.
type calculatedRemote struct {
ipNet net.IPNet
maskIP iputil.VpnIp
mask iputil.VpnIp
port uint32
}
func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) {
// Ensure this is an IPv4 mask that we expect
ones, bits := ipNet.Mask.Size()
if ones == 0 || bits != 32 {
return nil, fmt.Errorf("invalid mask: %v", ipNet)
}
if port < 0 || port > math.MaxUint16 {
return nil, fmt.Errorf("invalid port: %d", port)
}
return &calculatedRemote{
ipNet: *ipNet,
maskIP: iputil.Ip2VpnIp(ipNet.IP),
mask: iputil.Ip2VpnIp(ipNet.Mask),
port: uint32(port),
}, nil
}
func (c *calculatedRemote) String() string {
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
}
func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
// Combine the masked bytes of the "mask" IP with the unmasked bytes
// of the overlay IP
masked := (c.maskIP & c.mask) | (ip & ^c.mask)
return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
}
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4, error) {
value := c.Get(k)
if value == nil {
return nil, nil
}
calculatedRemotes := cidr.NewTree4()
rawMap, ok := value.(map[any]any)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
}
for rawKey, rawValue := range rawMap {
rawCIDR, ok := rawKey.(string)
if !ok {
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
_, ipNet, err := net.ParseCIDR(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
entry, err := newCalculatedRemotesListFromConfig(rawValue)
if err != nil {
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
}
calculatedRemotes.AddCIDR(ipNet, entry)
}
return calculatedRemotes, nil
}
func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
rawList, ok := raw.([]any)
if !ok {
return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
}
var l []*calculatedRemote
for _, e := range rawList {
c, err := newCalculatedRemotesEntryFromConfig(e)
if err != nil {
return nil, fmt.Errorf("calculated_remotes entry: %w", err)
}
l = append(l, c)
}
return l, nil
}
func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
rawMap, ok := raw.(map[any]any)
if !ok {
return nil, fmt.Errorf("invalid type: %T", raw)
}
rawValue := rawMap["mask"]
if rawValue == nil {
return nil, fmt.Errorf("missing mask: %v", rawMap)
}
rawMask, ok := rawValue.(string)
if !ok {
return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue)
}
_, ipNet, err := net.ParseCIDR(rawMask)
if err != nil {
return nil, fmt.Errorf("invalid mask: %s", rawMask)
}
var port int
rawValue = rawMap["port"]
if rawValue == nil {
return nil, fmt.Errorf("missing port: %v", rawMap)
}
switch v := rawValue.(type) {
case int:
port = v
case string:
port, err = strconv.Atoi(v)
if err != nil {
return nil, fmt.Errorf("invalid port: %s: %w", v, err)
}
default:
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
}
return newCalculatedRemote(ipNet, port)
}

27
calculated_remote_test.go Normal file
View File

@ -0,0 +1,27 @@
package nebula
import (
"net"
"testing"
"github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCalculatedRemoteApply(t *testing.T) {
_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
require.NoError(t, err)
c, err := newCalculatedRemote(ipNet, 4242)
require.NoError(t, err)
input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182})
expected := &Ip4AndPort{
Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})),
Port: 4242,
}
assert.Equal(t, expected, c.Apply(input))
}

View File

@ -91,6 +91,19 @@ lighthouse:
#- "1.1.1.1:4242"
#- "1.2.3.4:0" # port will be replaced with the real listening port
# EXPERIMENTAL: This option may change or disappear in the future.
# This setting allows us to "guess" what the remote might be for a host
# while we wait for the lighthouse response.
#calculated_remotes:
# For any Nebula IPs in 10.0.10.0/24, this will apply the mask and add
# the calculated IP as an initial remote (while we wait for the response
# from the lighthouse). Both CIDRs must have the same mask size.
# For example, Nebula IP 10.0.10.123 will have a calculated remote of
# 192.168.1.123
#10.0.10.0/24:
#- mask: 192.168.1.0/24
# port: 4242
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
# however using port 0 will dynamically assign a port and is recommended for roaming nodes.
listen:

View File

@ -142,14 +142,6 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
return
}
// We only care about a lighthouse trigger before the first handshake transmit attempt. This is a very specific
// optimization for a fast lighthouse reply
//TODO: it would feel better to do this once, anytime, as our delay increases over time
if lighthouseTriggered && hostinfo.HandshakeCounter > 0 {
// If we didn't return here a lighthouse could cause us to aggressively send handshakes
return
}
// Get a remotes object if we don't already have one.
// This is mainly to protect us as this should never be the case
// NB ^ This comment doesn't jive. It's how the thing gets initialized.
@ -158,8 +150,22 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
}
//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
remotes := hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)
remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes)
// We only care about a lighthouse trigger if we have new remotes to send to.
// This is a very specific optimization for a fast lighthouse reply.
if lighthouseTriggered && !remotesHaveChanged {
// If we didn't return here a lighthouse could cause us to aggressively send handshakes
return
}
hostinfo.HandshakeLastRemotes = remotes
// TODO: this will generate a load of queries for hosts with only 1 ip
// (such as ones registered to the lighthouse with only a private IP)
// So we only do it one time after attempting 5 handshakes already.
if len(remotes) <= 1 && hostinfo.HandshakeCounter == 5 {
// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
// the learned public ip for them. Query again to short circuit the promotion counter
@ -182,12 +188,18 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
}
})
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout
if len(sentTo) > 0 {
// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
// so only log when the list of remotes has changed
if remotesHaveChanged {
hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent")
} else if c.l.IsLevelEnabled(logrus.DebugLevel) {
hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Debug("Handshake message sent")
}
if c.config.useRelays && len(hostinfo.remotes.relays) > 0 {

View File

@ -66,46 +66,6 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.NotContains(t, blah.pendingHostMap.Hosts, ip)
}
func Test_NewHandshakeManagerTrigger(t *testing.T) {
l := test.NewLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
lh := newTestLighthouse()
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
hi := blah.AddVpnIp(ip, nil)
hi.HandshakeReady = true
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
// Trigger the same method the channel will but, this should set our remotes pointer
blah.handleOutbound(ip, mw, true)
assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have done a handshake attempt")
assert.NotNil(t, hi.remotes, "Manager should have set my remotes pointer")
// Make sure the trigger doesn't double schedule the timer entry
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
uaddr := udp.NewAddrFromString("10.1.1.1:4242")
hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
// We now have remotes but only the first trigger should have pushed things forward
blah.handleOutbound(ip, mw, true)
assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have not done a handshake attempt")
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
}
func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
for _, i := range tw.t.wheel {
n := i.Head

View File

@ -155,22 +155,23 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
type HostInfo struct {
sync.RWMutex
remote *udp.Addr
remotes *RemoteList
promoteCounter atomic.Uint32
ConnectionState *ConnectionState
handshakeStart time.Time //todo: this an entry in the handshake manager
HandshakeReady bool //todo: being in the manager means you are ready
HandshakeCounter int //todo: another handshake manager entry
HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready
HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry
packetStore []*cachedPacket //todo: this is other handshake manager entry
remoteIndexId uint32
localIndexId uint32
vpnIp iputil.VpnIp
recvError int
remoteCidr *cidr.Tree4
relayState RelayState
remote *udp.Addr
remotes *RemoteList
promoteCounter atomic.Uint32
ConnectionState *ConnectionState
handshakeStart time.Time //todo: this an entry in the handshake manager
HandshakeReady bool //todo: being in the manager means you are ready
HandshakeCounter int //todo: another handshake manager entry
HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time
HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready
HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry
packetStore []*cachedPacket //todo: this is other handshake manager entry
remoteIndexId uint32
localIndexId uint32
vpnIp iputil.VpnIp
recvError int
remoteCidr *cidr.Tree4
relayState RelayState
// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like

View File

@ -153,7 +153,13 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
if _, ok := f.lightHouse.GetStaticHostList()[vpnIp]; ok {
_, doTrigger := f.lightHouse.GetStaticHostList()[vpnIp]
if !doTrigger {
// Add any calculated remotes, and trigger early handshake if one found
doTrigger = f.lightHouse.addCalculatedRemotes(vpnIp)
}
if doTrigger {
select {
case f.handshakeManager.trigger <- vpnIp:
default:

View File

@ -12,6 +12,7 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
@ -72,6 +73,8 @@ type LightHouse struct {
// IP's of relays that can be used by peers to access me
relaysForMe atomic.Pointer[[]iputil.VpnIp]
calculatedRemotes atomic.Pointer[cidr.Tree4] // Maps VpnIp to []*calculatedRemote
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
l *logrus.Logger
@ -161,6 +164,10 @@ func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
return *lh.relaysForMe.Load()
}
func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4 {
return lh.calculatedRemotes.Load()
}
func (lh *LightHouse) GetUpdateInterval() int64 {
return lh.interval.Load()
}
@ -237,6 +244,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
}
}
if initial || c.HasChanged("lighthouse.calculated_remotes") {
cr, err := NewCalculatedRemotesFromConfig(c, "lighthouse.calculated_remotes")
if err != nil {
return util.NewContextualError("Invalid lighthouse.calculated_remotes", nil, err)
}
lh.calculatedRemotes.Store(cr)
if !initial {
//TODO: a diff will be annoyingly difficult
lh.l.Info("lighthouse.calculated_remotes has changed")
}
}
//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
if initial || c.HasChanged("static_host_map") {
staticList := make(map[iputil.VpnIp]struct{})
@ -488,6 +508,39 @@ func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, stat
staticList[vpnIp] = struct{}{}
}
// addCalculatedRemotes adds any calculated remotes based on the
// lighthouse.calculated_remotes configuration. It returns true if any
// calculated remotes were added
func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
tree := lh.getCalculatedRemotes()
if tree == nil {
return false
}
value := tree.MostSpecificContains(vpnIp)
if value == nil {
return false
}
calculatedRemotes := value.([]*calculatedRemote)
var calculated []*Ip4AndPort
for _, cr := range calculatedRemotes {
c := cr.Apply(vpnIp)
if c != nil {
calculated = append(calculated, c)
}
}
lh.Lock()
am := lh.unlockedGetRemoteList(vpnIp)
am.Lock()
defer am.Unlock()
lh.Unlock()
am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4)
return len(calculated) > 0
}
// unlockedGetRemoteList assumes you have the lh lock
func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
am, ok := lh.addrMap[vpnIp]

View File

@ -64,6 +64,22 @@ func (ua *Addr) Copy() *Addr {
return &nu
}
type AddrSlice []*Addr
func (a AddrSlice) Equal(b AddrSlice) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !a[i].Equals(b[i]) {
return false
}
}
return true
}
func ParseIPAndPort(s string) (net.IP, uint16, error) {
rIp, sPort, err := net.SplitHostPort(s)
if err != nil {