Pull hostmap and pending hostmap apart, remove unused functions (#843)

This commit is contained in:
Nate Brown 2023-07-24 12:37:52 -05:00 committed by GitHub
parent 52c9e360e7
commit a10baeee92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 292 additions and 295 deletions

View File

@ -42,7 +42,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
@ -121,7 +121,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
@ -207,7 +207,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, vpncidr, preferredRanges)
// Generate keys for CA and peer's cert.
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
@ -268,12 +268,16 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
ifce.connectionManager = nc
hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
peerCert: &peerCert,
H: &noise.HandshakeState{},
hostinfo := &HostInfo{
vpnIp: vpnIp,
ConnectionState: &ConnectionState{
certState: cs,
peerCert: &peerCert,
H: &noise.HandshakeState{},
},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
// Move ahead 45s.
// Check if to disconnect with invalid certificate.

View File

@ -17,6 +17,15 @@ import (
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
type controlEach func(h *HostInfo)
type controlHostLister interface {
QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo
ForEachIndex(each controlEach)
ForEachVpnIp(each controlEach)
GetPreferredRanges() []*net.IPNet
}
type Control struct {
f *Interface
l *logrus.Logger
@ -98,7 +107,7 @@ func (c *Control) RebindUDPServer() {
// ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip
func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
if pendingMap {
return listHostMapHosts(c.f.handshakeManager.pendingHostMap)
return listHostMapHosts(c.f.handshakeManager)
} else {
return listHostMapHosts(c.f.hostMap)
}
@ -107,7 +116,7 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
if pendingMap {
return listHostMapIndexes(c.f.handshakeManager.pendingHostMap)
return listHostMapIndexes(c.f.handshakeManager)
} else {
return listHostMapIndexes(c.f.hostMap)
}
@ -115,15 +124,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
var hm *HostMap
var hl controlHostLister
if pending {
hm = c.f.handshakeManager.pendingHostMap
hl = c.f.handshakeManager
} else {
hm = c.f.hostMap
hl = c.f.hostMap
}
h, err := hm.QueryVpnIp(vpnIp)
if err != nil {
h := hl.QueryVpnIp(vpnIp)
if h == nil {
return nil
}
@ -133,8 +142,8 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
// SetRemoteForTunnel forces a tunnel to use a specific remote
func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return nil
}
@ -145,8 +154,8 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return false
}
@ -241,28 +250,20 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
return chi
}
func listHostMapHosts(hm *HostMap) []ControlHostInfo {
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v, hm.preferredRanges)
i++
}
hm.RUnlock()
func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges()
hl.ForEachVpnIp(func(hostinfo *HostInfo) {
hosts = append(hosts, copyHostInfo(hostinfo, pr))
})
return hosts
}
func listHostMapIndexes(hm *HostMap) []ControlHostInfo {
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Indexes))
i := 0
for _, v := range hm.Indexes {
hosts[i] = copyHostInfo(v, hm.preferredRanges)
i++
}
hm.RUnlock()
func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges()
hl.ForEachIndex(func(hostinfo *HostInfo) {
hosts = append(hosts, copyHostInfo(hostinfo, pr))
})
return hosts
}

View File

@ -18,7 +18,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{
@ -50,7 +50,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
remotes := NewRemoteList(nil)
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
hm.unlockedAddHostInfo(&HostInfo{
remote: remote1,
remotes: remotes,
ConnectionState: &ConnectionState{
@ -64,9 +64,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
relayForByIp: map[iputil.VpnIp]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
})
}, &Interface{})
hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{
hm.unlockedAddHostInfo(&HostInfo{
remote: remote1,
remotes: remotes,
ConnectionState: &ConnectionState{
@ -80,7 +80,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
relayForByIp: map[iputil.VpnIp]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
})
}, &Interface{})
c := Control{
f: &Interface{

View File

@ -147,12 +147,12 @@ func (c *Control) GetUDPAddr() string {
}
func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)]
if !ok {
hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp))
if hostinfo == nil {
return false
}
c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
c.f.handshakeManager.DeleteHostInfo(hostinfo)
return true
}

View File

@ -47,8 +47,8 @@ func (d *dnsRecords) QueryCert(data string) string {
return ""
}
iip := iputil.Ip2VpnIp(ip)
hostinfo, err := d.hostMap.QueryVpnIp(iip)
if err != nil {
hostinfo := d.hostMap.QueryVpnIp(iip)
if hostinfo == nil {
return ""
}
q := hostinfo.GetCert()

View File

@ -20,7 +20,7 @@ func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packe
case 1:
ixHandshakeStage1(f, addr, via, packet, h)
case 2:
newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex)
tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h)
if tearDown && newHostinfo != nil {
f.handshakeManager.DeleteHostInfo(newHostinfo)

View File

@ -422,7 +422,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
Info("Incorrect host responded to handshake")
// Release our old handshake from pending, it should not continue
f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip
//TODO: this adds it to the timer wheel in a way that aggressively retries

View File

@ -7,6 +7,7 @@ import (
"encoding/binary"
"errors"
"net"
"sync"
"time"
"github.com/rcrowley/go-metrics"
@ -42,7 +43,12 @@ type HandshakeConfig struct {
}
type HandshakeManager struct {
pendingHostMap *HostMap
// Mutex for interacting with the vpnIps and indexes maps
sync.RWMutex
vpnIps map[iputil.VpnIp]*HostInfo
indexes map[uint32]*HostInfo
mainHostMap *HostMap
lightHouse *LightHouse
outside udp.Conn
@ -59,7 +65,8 @@ type HandshakeManager struct {
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
vpnIps: map[iputil.VpnIp]*HostInfo{},
indexes: map[uint32]*HostInfo{},
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
@ -101,8 +108,8 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
}
func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil {
hostinfo := c.QueryVpnIp(vpnIp)
if hostinfo == nil {
return
}
hostinfo.Lock()
@ -111,7 +118,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
// We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
if hostinfo.HandshakeComplete {
// Ensure we don't exist in the pending hostmap anymore since we have completed
c.pendingHostMap.DeleteHostInfo(hostinfo)
c.DeleteHostInfo(hostinfo)
return
}
@ -125,14 +132,14 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
// If we are out of time, clean up
if hostinfo.HandshakeCounter >= c.config.retries {
hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)).
hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
Info("Handshake timed out")
c.metricTimedOut.Inc(1)
c.pendingHostMap.DeleteHostInfo(hostinfo)
c.DeleteHostInfo(hostinfo)
return
}
@ -144,7 +151,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
}
remotes := hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)
remotes := hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)
remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes)
// We only care about a lighthouse trigger if we have new remotes to send to.
@ -168,9 +175,9 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udp.Addr
hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
hostinfo.remotes.ForEach(c.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
hostinfo.logger(c.l).WithField("udpAddr", addr).
WithField("initiatorIndex", hostinfo.localIndexId).
@ -204,9 +211,9 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
if *relay == vpnIp || *relay == c.lightHouse.myVpnIp {
continue
}
relayHostInfo, err := c.mainHostMap.QueryVpnIp(*relay)
if err != nil || relayHostInfo.remote == nil {
hostinfo.logger(c.l).WithError(err).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
relayHostInfo := c.mainHostMap.QueryVpnIp(*relay)
if relayHostInfo == nil || relayHostInfo.remote == nil {
hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
f.Handshake(*relay)
continue
}
@ -289,14 +296,35 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
}
}
// AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it.
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
// A write lock is used to avoid having to recheck the map and trading a read lock for a write lock
c.Lock()
defer c.Unlock()
if created {
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
c.metricInitiated.Inc(1)
if hostinfo, ok := c.vpnIps[vpnIp]; ok {
// We are already tracking this vpn ip
return hostinfo
}
hostinfo := &HostInfo{
vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
relays: map[iputil.VpnIp]struct{}{},
relayForByIp: map[iputil.VpnIp]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
if init != nil {
init(hostinfo)
}
c.vpnIps[vpnIp] = hostinfo
c.metricInitiated.Inc(1)
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
return hostinfo
}
@ -318,8 +346,8 @@ var (
// ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.Lock()
defer c.Unlock()
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
@ -350,7 +378,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
return existingIndex, ErrLocalIndexCollision
}
existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
existingIndex, found = c.indexes[hostinfo.localIndexId]
if found && existingIndex != hostinfo {
// We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision
@ -373,8 +401,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
// won't have a localIndexId collision because we already have an entry in the
// pendingHostMap. An existing hostinfo is returned if there was one.
func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.Lock()
defer c.Unlock()
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
@ -388,7 +416,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
}
// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
c.unlockedDeleteHostInfo(hostinfo)
c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
}
@ -396,8 +424,8 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
// and adds it to the pendingHostMap. Will error if we are unable to generate
// a unique localIndexId
func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
c.pendingHostMap.Lock()
defer c.pendingHostMap.Unlock()
c.Lock()
defer c.Unlock()
c.mainHostMap.RLock()
defer c.mainHostMap.RUnlock()
@ -407,12 +435,12 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
return err
}
_, inPending := c.pendingHostMap.Indexes[index]
_, inPending := c.indexes[index]
_, inMain := c.mainHostMap.Indexes[index]
if !inMain && !inPending {
h.localIndexId = index
c.pendingHostMap.Indexes[index] = h
c.indexes[index] = h
return nil
}
}
@ -420,22 +448,73 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
return errors.New("failed to generate unique localIndexId")
}
func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
c.pendingHostMap.addRemoteIndexHostInfo(index, h)
}
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
//l.Debugln("Deleting pending hostinfo :", hostinfo)
c.pendingHostMap.DeleteHostInfo(hostinfo)
c.Lock()
defer c.Unlock()
c.unlockedDeleteHostInfo(hostinfo)
}
func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
return c.pendingHostMap.QueryIndex(index)
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
delete(c.vpnIps, hostinfo.vpnIp)
if len(c.vpnIps) == 0 {
c.vpnIps = map[iputil.VpnIp]*HostInfo{}
}
delete(c.indexes, hostinfo.localIndexId)
if len(c.vpnIps) == 0 {
c.indexes = map[uint32]*HostInfo{}
}
if c.l.Level >= logrus.DebugLevel {
c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Pending hostmap hostInfo deleted")
}
}
func (c *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
c.RLock()
defer c.RUnlock()
return c.vpnIps[vpnIp]
}
func (c *HandshakeManager) QueryIndex(index uint32) *HostInfo {
c.RLock()
defer c.RUnlock()
return c.indexes[index]
}
func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
return c.mainHostMap.preferredRanges
}
func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
c.RLock()
defer c.RUnlock()
for _, v := range c.vpnIps {
f(v)
}
}
func (c *HandshakeManager) ForEachIndex(f controlEach) {
c.RLock()
defer c.RUnlock()
for _, v := range c.indexes {
f(v)
}
}
func (c *HandshakeManager) EmitStats() {
c.pendingHostMap.EmitStats("pending")
c.mainHostMap.EmitStats("main")
c.RLock()
hostLen := len(c.vpnIps)
indexLen := len(c.indexes)
c.RUnlock()
metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
c.mainHostMap.EmitStats()
}
// Utility functions below

View File

@ -20,7 +20,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
mainHM := NewHostMap(l, vpncidr, preferredRanges)
lh := newTestLighthouse()
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@ -48,7 +48,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.Len(t, mainHM.Hosts, 0)
// Confirm they are in the pending index list
assert.Contains(t, blah.pendingHostMap.Hosts, ip)
assert.Contains(t, blah.vpnIps, ip)
// Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
for i := 1; i <= DefaultHandshakeRetries+1; i++ {
@ -57,13 +57,13 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
}
// Confirm they are still in the pending index list
assert.Contains(t, blah.pendingHostMap.Hosts, ip)
assert.Contains(t, blah.vpnIps, ip)
// Tick 1 more time, a minute will certainly flush it out
blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw)
// Confirm they have been removed
assert.NotContains(t, blah.pendingHostMap.Hosts, ip)
assert.NotContains(t, blah.vpnIps, ip)
}
func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {

View File

@ -2,7 +2,6 @@ package nebula
import (
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
@ -52,7 +51,6 @@ type Relay struct {
type HostMap struct {
sync.RWMutex //Because we concurrently read and write to our maps
name string
Indexes map[uint32]*HostInfo
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
RemoteIndexes map[uint32]*HostInfo
@ -203,13 +201,13 @@ type HostInfo struct {
remotes *RemoteList
promoteCounter atomic.Uint32
ConnectionState *ConnectionState
handshakeStart time.Time //todo: this an entry in the handshake manager
HandshakeReady bool //todo: being in the manager means you are ready
HandshakeCounter int //todo: another handshake manager entry
HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time
HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready
HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry
packetStore []*cachedPacket //todo: this is other handshake manager entry
handshakeStart time.Time //todo: this an entry in the handshake manager
HandshakeReady bool //todo: being in the manager means you are ready
HandshakeCounter int //todo: another handshake manager entry
HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time
HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready
HandshakePacket map[uint8][]byte
packetStore []*cachedPacket //todo: this is other handshake manager entry
remoteIndexId uint32
localIndexId uint32
vpnIp iputil.VpnIp
@ -255,13 +253,12 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[iputil.VpnIp]*HostInfo{}
i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{}
relays := map[uint32]*HostInfo{}
m := HostMap{
name: name,
Indexes: i,
Relays: relays,
RemoteIndexes: r,
@ -273,8 +270,8 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
return &m
}
// UpdateStats takes a name and reports host and index counts to the stats collection system
func (hm *HostMap) EmitStats(name string) {
// EmitStats reports host, index, and relay counts to the stats collection system
func (hm *HostMap) EmitStats() {
hm.RLock()
hostLen := len(hm.Hosts)
indexLen := len(hm.Indexes)
@ -282,10 +279,10 @@ func (hm *HostMap) EmitStats(name string) {
relaysLen := len(hm.Relays)
hm.RUnlock()
metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen))
metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
metrics.GetOrRegisterGauge("hostmap."+name+".relayIndexes", nil).Update(int64(relaysLen))
metrics.GetOrRegisterGauge("hostmap.main.hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap.main.indexes", nil).Update(int64(indexLen))
metrics.GetOrRegisterGauge("hostmap.main.remoteIndexes", nil).Update(int64(remoteIndexLen))
metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
}
func (hm *HostMap) RemoveRelay(localIdx uint32) {
@ -299,88 +296,6 @@ func (hm *HostMap) RemoveRelay(localIdx uint32) {
hm.Unlock()
}
func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
hm.RLock()
if i, ok := hm.Hosts[vpnIp]; ok {
index := i.localIndexId
hm.RUnlock()
return index, nil
}
hm.RUnlock()
return 0, errors.New("vpn IP not found")
}
func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
hm.Lock()
hm.Hosts[ip] = hostinfo
hm.Unlock()
}
func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (hostinfo *HostInfo, created bool) {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; !ok {
hm.RUnlock()
h = &HostInfo{
vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
relays: map[iputil.VpnIp]struct{}{},
relayForByIp: map[iputil.VpnIp]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
if init != nil {
init(h)
}
hm.Lock()
hm.Hosts[vpnIp] = h
hm.Unlock()
return h, true
} else {
hm.RUnlock()
return h, false
}
}
// Only used by pendingHostMap when the remote index is not initially known
func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
hm.Lock()
h.remoteIndexId = index
hm.RemoteIndexes[index] = h
hm.Unlock()
if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}).
Debug("Hostmap remoteIndex added")
}
}
// DeleteReverseIndex is used to clean up on recv_error
// This function should only ever be called on the pending hostmap
func (hm *HostMap) DeleteReverseIndex(index uint32) {
hm.Lock()
hostinfo, ok := hm.RemoteIndexes[index]
if ok {
delete(hm.Indexes, hostinfo.localIndexId)
delete(hm.RemoteIndexes, index)
// Check if we have an entry under hostId that matches the same hostinfo
// instance. Clean it up as well if we do (they might not match in pendingHostmap)
var hostinfo2 *HostInfo
hostinfo2, ok = hm.Hosts[hostinfo.vpnIp]
if ok && hostinfo2 == hostinfo {
delete(hm.Hosts, hostinfo.vpnIp)
}
}
hm.Unlock()
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
Debug("Hostmap remote index deleted")
}
}
// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
// Delete the host itself, ensuring it's not modified anymore
@ -393,12 +308,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
return final
}
func (hm *HostMap) DeleteRelayIdx(localIdx uint32) {
hm.Lock()
defer hm.Unlock()
delete(hm.RemoteIndexes, localIdx)
}
func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
hm.Lock()
defer hm.Unlock()
@ -476,7 +385,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted")
}
@ -486,55 +395,41 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
}
}
func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
//TODO: we probably just want to return bool instead of error, or at least a static error
func (hm *HostMap) QueryIndex(index uint32) *HostInfo {
hm.RLock()
if h, ok := hm.Indexes[index]; ok {
hm.RUnlock()
return h, nil
return h
} else {
hm.RUnlock()
return nil, errors.New("unable to find index")
return nil
}
}
// Retrieves a HostInfo by Index. Returns whether the HostInfo is primary at time of query.
// This helper exists so that the hostinfo.prev pointer can be read while the hostmap lock is held.
func (hm *HostMap) QueryIndexIsPrimary(index uint32) (*HostInfo, bool, error) {
//TODO: we probably just want to return bool instead of error, or at least a static error
hm.RLock()
if h, ok := hm.Indexes[index]; ok {
hm.RUnlock()
return h, h.prev == nil, nil
} else {
hm.RUnlock()
return nil, false, errors.New("unable to find index")
}
}
func (hm *HostMap) QueryRelayIndex(index uint32) (*HostInfo, error) {
func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo {
//TODO: we probably just want to return bool instead of error, or at least a static error
hm.RLock()
if h, ok := hm.Relays[index]; ok {
hm.RUnlock()
return h, nil
return h
} else {
hm.RUnlock()
return nil, errors.New("unable to find index")
return nil
}
}
func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
hm.RLock()
if h, ok := hm.RemoteIndexes[index]; ok {
hm.RUnlock()
return h, nil
return h
} else {
hm.RUnlock()
return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name)
return nil
}
}
func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
return hm.queryVpnIp(vpnIp, nil)
}
@ -558,11 +453,11 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
// `PromoteEvery` calls to this function for a given host.
func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) *HostInfo {
return hm.queryVpnIp(vpnIp, ifce)
}
func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) {
func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock()
@ -570,12 +465,12 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host
if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
h.TryPromoteBest(hm.preferredRanges, promoteIfce)
}
return h, nil
return h
}
hm.RUnlock()
return nil, errors.New("unable to find host")
return nil
}
// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
@ -598,7 +493,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
Debug("Hostmap vpnIp added")
}
@ -614,6 +509,28 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
}
}
func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
return hm.preferredRanges
}
func (hm *HostMap) ForEachVpnIp(f controlEach) {
hm.RLock()
defer hm.RUnlock()
for _, v := range hm.Hosts {
f(v)
}
}
func (hm *HostMap) ForEachIndex(f controlEach) {
hm.RLock()
defer hm.RUnlock()
for _, v := range hm.Indexes {
f(v)
}
}
// TryPromoteBest handles re-querying lighthouses and probing for better paths
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {

View File

@ -11,7 +11,7 @@ import (
func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
l, "test",
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
@ -32,7 +32,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.unlockedAddHostInfo(h1, f)
// Make sure we go h1 -> h2 -> h3 -> h4
prim, _ := hm.QueryVpnIp(1)
prim := hm.QueryVpnIp(1)
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -47,7 +47,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h3)
// Make sure we go h3 -> h1 -> h2 -> h4
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h3.localIndexId, prim.localIndexId)
assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -62,7 +62,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -77,7 +77,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -92,7 +92,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
l, "test",
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
@ -119,11 +119,11 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
// h6 should be deleted
assert.Nil(t, h6.next)
assert.Nil(t, h6.prev)
_, err := hm.QueryIndex(h6.localIndexId)
assert.Error(t, err)
h := hm.QueryIndex(h6.localIndexId)
assert.Nil(t, h)
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
prim, _ := hm.QueryVpnIp(1)
prim := hm.QueryVpnIp(1)
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -142,7 +142,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h1.next)
// Make sure we go h2 -> h3 -> h4 -> h5
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -160,7 +160,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h3.next)
// Make sure we go h2 -> h4 -> h5
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -176,7 +176,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h5.next)
// Make sure we go h2 -> h4
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -190,7 +190,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h2.next)
// Make sure we only have h4
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Nil(t, prim.prev)
assert.Nil(t, prim.next)
@ -202,6 +202,6 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h4.next)
// Make sure we have nil
prim, _ = hm.QueryVpnIp(1)
prim = hm.QueryVpnIp(1)
assert.Nil(t, prim)
}

View File

@ -121,14 +121,10 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
return nil
}
}
hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
//if err != nil || hostinfo.ConnectionState == nil {
if err != nil {
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
if err != nil {
hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
}
hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
if hostinfo == nil {
hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
}
ci := hostinfo.ConnectionState
@ -137,6 +133,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
}
// Handshake is not ready, we need to grab the lock now before we start the handshake process
//TODO: move this to handshake manager
hostinfo.Lock()
defer hostinfo.Unlock()

View File

@ -212,7 +212,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
hostMap := NewHostMap(l, tunCidr, preferredRanges)
hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
l.
@ -339,7 +339,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
//TODO: check if we _should_ be emitting stats
go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
attachCommands(l, c, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
attachCommands(l, c, ssh, ifce)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
var dnsStart func()

View File

@ -64,9 +64,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
var hostinfo *HostInfo
// verify if we've seen this index before, otherwise respond to the handshake initiation
if h.Type == header.Message && h.Subtype == header.MessageRelay {
hostinfo, _ = f.hostMap.QueryRelayIndex(h.RemoteIndex)
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
} else {
hostinfo, _ = f.hostMap.QueryIndex(h.RemoteIndex)
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
}
var ci *ConnectionState
@ -449,12 +449,9 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
Debug("Recv error received")
}
// First, clean up in the pending hostmap
f.handshakeManager.pendingHostMap.DeleteReverseIndex(h.RemoteIndex)
hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if err != nil {
f.l.Debugln(err, ": ", h.RemoteIndex)
hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if hostinfo == nil {
f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
return
}
@ -464,14 +461,14 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
if !hostinfo.RecvErrorExceeded() {
return
}
if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return
}
f.closeTunnel(hostinfo)
// We also delete it from pending hostmap to allow for
// fast reconnect.
// We also delete it from pending hostmap to allow for fast reconnect.
f.handshakeManager.DeleteHostInfo(hostinfo)
}

View File

@ -131,9 +131,9 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
return
}
// I'm the middle man. Let the initiator know that the I've established the relay they requested.
peerHostInfo, err := rm.hostmap.QueryVpnIp(relay.PeerIp)
if err != nil {
rm.l.WithError(err).WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp)
if peerHostInfo == nil {
rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
return
}
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target)
@ -240,8 +240,8 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
if !rm.GetAmRelay() {
return
}
peer, err := rm.hostmap.QueryVpnIp(target)
if err != nil {
peer := rm.hostmap.QueryVpnIp(target)
if peer == nil {
// Try to establish a connection to this host. If we get a future relay request,
// we'll be ready!
f.getOrHandshake(target)
@ -253,6 +253,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
}
sendCreateRequest := false
var index uint32
var err error
targetRelay, ok := peer.relayState.QueryRelayForByIp(from)
if ok {
index = targetRelay.LocalIndex

61
ssh.go
View File

@ -3,6 +3,7 @@ package nebula
import (
"bytes"
"encoding/json"
"errors"
"flag"
"fmt"
"io/ioutil"
@ -168,7 +169,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
return runner, nil
}
func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
ssh.RegisterCommand(&sshd.Command{
Name: "list-hostmap",
ShortDescription: "List all known previously connected hosts",
@ -181,7 +182,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListHostMap(hostMap, fs, w)
return sshListHostMap(f.hostMap, fs, w)
},
})
@ -197,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListHostMap(pendingHostMap, fs, w)
return sshListHostMap(f.handshakeManager, fs, w)
},
})
@ -212,7 +213,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListLighthouseMap(lightHouse, fs, w)
return sshListLighthouseMap(f.lightHouse, fs, w)
},
})
@ -277,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
Name: "version",
ShortDescription: "Prints the currently running version of nebula",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshVersion(ifce, fs, a, w)
return sshVersion(f, fs, a, w)
},
})
@ -293,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshPrintCert(ifce, fs, a, w)
return sshPrintCert(f, fs, a, w)
},
})
@ -307,7 +308,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshPrintTunnel(ifce, fs, a, w)
return sshPrintTunnel(f, fs, a, w)
},
})
@ -321,7 +322,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshPrintRelays(ifce, fs, a, w)
return sshPrintRelays(f, fs, a, w)
},
})
@ -335,7 +336,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshChangeRemote(ifce, fs, a, w)
return sshChangeRemote(f, fs, a, w)
},
})
@ -349,7 +350,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshCloseTunnel(ifce, fs, a, w)
return sshCloseTunnel(f, fs, a, w)
},
})
@ -364,7 +365,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshCreateTunnel(ifce, fs, a, w)
return sshCreateTunnel(f, fs, a, w)
},
})
@ -373,12 +374,12 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
ShortDescription: "Query the lighthouses for the provided vpn ip",
Help: "This command is asynchronous. Only currently known udp ips will be printed.",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshQueryLighthouse(ifce, fs, a, w)
return sshQueryLighthouse(f, fs, a, w)
},
})
}
func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error {
func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error {
fs, ok := a.(*sshListHostMapFlags)
if !ok {
//TODO: error
@ -387,9 +388,9 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
var hm []ControlHostInfo
if fs.ByIndex {
hm = listHostMapIndexes(hostMap)
hm = listHostMapIndexes(hl)
} else {
hm = listHostMapHosts(hostMap)
hm = listHostMapHosts(hl)
}
sort.Slice(hm, func(i, j int) bool {
@ -546,8 +547,8 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -588,12 +589,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp)
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
}
hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
}
@ -645,8 +646,8 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -765,8 +766,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -851,9 +852,9 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
for k, v := range relays {
ro := RelayOutput{NebulaIp: v.vpnIp}
co.Relays = append(co.Relays, &ro)
relayHI, err := ifce.hostMap.QueryVpnIp(v.vpnIp)
if err != nil {
ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: err})
relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp)
if relayHI == nil {
ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
continue
}
for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
@ -889,8 +890,8 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
}
}
relayedHI, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err == nil {
relayedHI := ifce.hostMap.QueryVpnIp(vpnIp)
if relayedHI != nil {
rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
}
@ -925,8 +926,8 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}