diff --git a/connection_manager.go b/connection_manager.go index 14086ac..3d0f8aa 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -1,6 +1,7 @@ package nebula import ( + "bytes" "context" "sync" "time" @@ -8,9 +9,20 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" ) +type trafficDecision int + +const ( + doNothing trafficDecision = 0 + deleteTunnel trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote + closeTunnel trafficDecision = 2 // delete the hostinfo and notify the remote + swapPrimary trafficDecision = 3 + migrateRelays trafficDecision = 4 +) + type connectionManager struct { in map[uint32]struct{} inLock *sync.RWMutex @@ -18,6 +30,10 @@ type connectionManager struct { out map[uint32]struct{} outLock *sync.RWMutex + // relayUsed holds which relay localIndexs are in use + relayUsed map[uint32]struct{} + relayUsedLock *sync.RWMutex + hostMap *HostMap trafficTimer *LockingTimerWheel[uint32] intf *Interface @@ -44,6 +60,8 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface inLock: &sync.RWMutex{}, out: make(map[uint32]struct{}), outLock: &sync.RWMutex{}, + relayUsed: make(map[uint32]struct{}), + relayUsedLock: &sync.RWMutex{}, trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max), intf: intf, pendingDeletion: make(map[uint32]struct{}), @@ -84,6 +102,19 @@ func (n *connectionManager) Out(localIndex uint32) { n.outLock.Unlock() } +func (n *connectionManager) RelayUsed(localIndex uint32) { + n.relayUsedLock.RLock() + // If this already exists, return + if _, ok := n.relayUsed[localIndex]; ok { + n.relayUsedLock.RUnlock() + return + } + n.relayUsedLock.RUnlock() + n.relayUsedLock.Lock() + n.relayUsed[localIndex] = struct{}{} + n.relayUsedLock.Unlock() +} + // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and // resets the state for this local index func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) { @@ -99,8 +130,15 @@ func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bo } func (n *connectionManager) AddTrafficWatch(localIndex uint32) { - n.Out(localIndex) + // Use a write lock directly because it should be incredibly rare that we are ever already tracking this index + n.outLock.Lock() + if _, ok := n.out[localIndex]; ok { + n.outLock.Unlock() + return + } + n.out[localIndex] = struct{}{} n.trafficTimer.Add(localIndex, n.checkInterval) + n.outLock.Unlock() } func (n *connectionManager) Start(ctx context.Context) { @@ -136,18 +174,130 @@ func (n *connectionManager) Run(ctx context.Context) { } func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { - hostinfo, err := n.hostMap.QueryIndex(localIndex) - if err != nil { + decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now) + + switch decision { + case deleteTunnel: + n.hostMap.DeleteHostInfo(hostinfo) + + case closeTunnel: + n.intf.sendCloseTunnel(hostinfo) + n.intf.closeTunnel(hostinfo) + + case swapPrimary: + n.swapPrimary(hostinfo, primary) + + case migrateRelays: + n.migrateRelayUsed(hostinfo, primary) + } + + n.resetRelayTrafficCheck(hostinfo) +} + +func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) { + if hostinfo != nil { + n.relayUsedLock.Lock() + defer n.relayUsedLock.Unlock() + // No need to migrate any relays, delete usage info now. + for _, idx := range hostinfo.relayState.CopyRelayForIdxs() { + delete(n.relayUsed, idx) + } + } +} + +func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) { + relayFor := oldhostinfo.relayState.CopyAllRelayFor() + + for _, r := range relayFor { + existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) + + var index uint32 + var relayFrom iputil.VpnIp + var relayTo iputil.VpnIp + switch { + case ok && existing.State == Established: + // This relay already exists in newhostinfo, then do nothing. + continue + case ok && existing.State == Requested: + // The relay exists in a Requested state; re-send the request + index = existing.LocalIndex + switch r.Type { + case TerminalType: + relayFrom = newhostinfo.vpnIp + relayTo = existing.PeerIp + case ForwardingType: + relayFrom = existing.PeerIp + relayTo = newhostinfo.vpnIp + default: + // should never happen + } + case !ok: + n.relayUsedLock.RLock() + if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed { + // The relay hasn't been used; don't migrate it. + n.relayUsedLock.RUnlock() + continue + } + n.relayUsedLock.RUnlock() + // The relay doesn't exist at all; create some relay state and send the request. + var err error + index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested) + if err != nil { + n.l.WithError(err).Error("failed to migrate relay to new hostinfo") + continue + } + switch r.Type { + case TerminalType: + relayFrom = newhostinfo.vpnIp + relayTo = r.PeerIp + case ForwardingType: + relayFrom = r.PeerIp + relayTo = newhostinfo.vpnIp + default: + // should never happen + } + } + + // Send a CreateRelayRequest to the peer. + req := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: index, + RelayFromIp: uint32(relayFrom), + RelayToIp: uint32(relayTo), + } + msg, err := req.Marshal() + if err != nil { + n.l.WithError(err).Error("failed to marshal Control message to migrate relay") + } else { + n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) + n.l.WithFields(logrus.Fields{ + "relayFrom": iputil.VpnIp(req.RelayFromIp), + "relayTo": iputil.VpnIp(req.RelayToIp), + "initiatorRelayIndex": req.InitiatorRelayIndex, + "responderRelayIndex": req.ResponderRelayIndex, + "vpnIp": newhostinfo.vpnIp}). + Info("send CreateRelayRequest") + } + } +} + +func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { + n.hostMap.RLock() + defer n.hostMap.RUnlock() + + hostinfo := n.hostMap.Indexes[localIndex] + if hostinfo == nil { n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") delete(n.pendingDeletion, localIndex) - return + return doNothing, nil, nil } - if n.handleInvalidCertificate(now, hostinfo) { - return + if n.isInvalidCertificate(now, hostinfo) { + delete(n.pendingDeletion, hostinfo.localIndexId) + return closeTunnel, hostinfo, nil } - primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp) + primary := n.hostMap.Hosts[hostinfo.vpnIp] mainHostInfo := true if primary != nil && primary != hostinfo { mainHostInfo = false @@ -158,6 +308,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, // A hostinfo is determined alive if there is incoming traffic if inTraffic { + decision := doNothing if n.l.Level >= logrus.DebugLevel { hostinfo.logger(n.l). WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). @@ -165,11 +316,14 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, } delete(n.pendingDeletion, hostinfo.localIndexId) - if !mainHostInfo { - if hostinfo.vpnIp > n.intf.myVpnIp { - // We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make - // This the primary and prime the old primary hostinfo for testing - n.hostMap.MakePrimary(hostinfo) + if mainHostInfo { + n.tryRehandshake(hostinfo) + } else { + if n.shouldSwapPrimary(hostinfo, primary) { + decision = swapPrimary + } else { + // migrate the relays to the primary, if in use. + decision = migrateRelays } } @@ -180,7 +334,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, n.sendPunch(hostinfo) } - return + return decision, hostinfo, primary } if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok { @@ -189,22 +343,17 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, WithField("tunnelCheck", m{"state": "dead", "method": "active"}). Info("Tunnel status") - n.hostMap.DeleteHostInfo(hostinfo) delete(n.pendingDeletion, hostinfo.localIndexId) - return + return deleteTunnel, hostinfo, nil } - hostinfo.logger(n.l). - WithField("tunnelCheck", m{"state": "testing", "method": "active"}). - Debug("Tunnel status") - if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { if !outTraffic { // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. // Just maintain NAT state if configured to do so. n.sendPunch(hostinfo) n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) - return + return doNothing, nil, nil } @@ -218,22 +367,58 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) { // We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) - return + return doNothing, nil, nil + } + + if n.l.Level >= logrus.DebugLevel { + hostinfo.logger(n.l). + WithField("tunnelCheck", m{"state": "testing", "method": "active"}). + Debug("Tunnel status") } // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out) + n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) } else { - hostinfo.logger(n.l).Debugf("Hostinfo sadness") + if n.l.Level >= logrus.DebugLevel { + hostinfo.logger(n.l).Debugf("Hostinfo sadness") + } } n.pendingDeletion[hostinfo.localIndexId] = struct{}{} n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval) + return doNothing, nil, nil } -// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid -func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { +func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { + + // The primary tunnel is the most recent handshake to complete locally and should work entirely fine. + // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. + // Let's sort this out. + + if current.vpnIp < n.intf.myVpnIp { + // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. + // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. + // The remotes vpn ip is lower than mine. I will not flip. + return false + } + + certState := n.intf.certState.Load() + return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) +} + +func (n *connectionManager) swapPrimary(current, primary *HostInfo) { + n.hostMap.Lock() + // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. + if n.hostMap.Hosts[current.vpnIp] == primary { + n.hostMap.unlockedMakePrimary(current) + } + n.hostMap.Unlock() +} + +// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and +// the certificate is no longer valid +func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { if !n.intf.disconnectInvalid { return false } @@ -253,10 +438,6 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *Ho WithField("fingerprint", fingerprint). Info("Remote certificate is no longer valid, tearing down the tunnel") - // Inform the remote and close the tunnel locally - n.intf.sendCloseTunnel(hostinfo) - n.intf.closeTunnel(hostinfo) - delete(n.pendingDeletion, hostinfo.localIndexId) return true } @@ -277,3 +458,29 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) } } + +func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { + certState := n.intf.certState.Load() + if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) { + return + } + + n.l.WithField("vpnIp", hostinfo.vpnIp). + WithField("reason", "local certificate is not current"). + Info("Re-handshaking with remote") + + //TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out + newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo) + if !newHostinfo.HandshakeReady { + ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo) + } + + //If this is a static host, we don't need to wait for the HostQueryReply + //We can trigger the handshake right now + if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok { + select { + case n.intf.handshakeManager.trigger <- hostinfo.vpnIp: + default: + } + } +} diff --git a/connection_manager_test.go b/connection_manager_test.go index 3d79cb0..2ea906b 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -279,13 +279,13 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { // Check if to disconnect with invalid certificate. // Should be alive. nextTick := now.Add(45 * time.Second) - destroyed := nc.handleInvalidCertificate(nextTick, hostinfo) - assert.False(t, destroyed) + invalid := nc.isInvalidCertificate(nextTick, hostinfo) + assert.False(t, invalid) // Move ahead 61s. // Check if to disconnect with invalid certificate. // Should be disconnected. nextTick = now.Add(61 * time.Second) - destroyed = nc.handleInvalidCertificate(nextTick, hostinfo) - assert.True(t, destroyed) + invalid = nc.isInvalidCertificate(nextTick, hostinfo) + assert.True(t, invalid) } diff --git a/control_tester.go b/control_tester.go index 550c986..48deb13 100644 --- a/control_tester.go +++ b/control_tester.go @@ -163,3 +163,17 @@ func (c *Control) GetHostmap() *HostMap { func (c *Control) GetCert() *cert.NebulaCertificate { return c.f.certState.Load().certificate } + +func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { + hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo) + ixHandshakeStage0(c.f, vpnIp, hostinfo) + + // If this is a static host, we don't need to wait for the HostQueryReply + // We can trigger the handshake right now + if _, ok := c.f.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok { + select { + case c.f.handshakeManager.trigger <- hostinfo.vpnIp: + default: + } + } +} diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 8e33deb..aa62603 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -4,6 +4,7 @@ package e2e import ( + "fmt" "net" "testing" "time" @@ -15,12 +16,13 @@ import ( "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v2" ) func BenchmarkHotPath(b *testing.B) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Put their info in our lighthouse myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) @@ -43,8 +45,8 @@ func BenchmarkHotPath(b *testing.B) { func TestGoodHandshake(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Put their info in our lighthouse myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) @@ -98,9 +100,9 @@ func TestWrongResponderHandshake(t *testing.T) { // The IPs here are chosen on purpose: // The current remote handling will sort by preference, public, and then lexically. // So we need them to have a higher address than evil (we could apply a preference though) - myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) - theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) - evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) // Add their real udp addr, which should be tried after evil. myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) @@ -163,8 +165,8 @@ func TestStage1Race(t *testing.T) { // But will eventually collapse down to a single tunnel ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Put their info in our lighthouse and vice versa myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) @@ -240,8 +242,8 @@ func TestStage1Race(t *testing.T) { func TestUncleanShutdownRaceLoser(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) @@ -289,8 +291,8 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { func TestUncleanShutdownRaceWinner(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) @@ -340,9 +342,9 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { func TestRelays(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) @@ -371,9 +373,9 @@ func TestRelays(t *testing.T) { func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) @@ -418,9 +420,9 @@ func TestStage1RaceRelays(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay @@ -503,5 +505,366 @@ func TestStage1RaceRelays2(t *testing.T) { // ////TODO: assert hostmaps } +func TestRehandshakingRelays(t *testing.T) { + ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) + + // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, + // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. + r.Log("Renew relay certificate and spin until me and them sees it") + _, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + + caB, err := ca.MarshalToPEM() + if err != nil { + panic(err) + } + + relayConfig.Settings["pki"] = m{ + "ca": string(caB), + "cert": string(myNextPEM), + "key": string(myNextPrivKey), + } + rc, err := yaml.Marshal(relayConfig.Settings) + assert.NoError(t, err) + relayConfig.ReloadConfigString(string(rc)) + + for { + r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") + assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + if len(c.Cert.Details.Groups) != 0 { + // We have a new certificate now + r.Log("Certificate between my and relay is updated!") + break + } + + time.Sleep(time.Second) + } + + for { + r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") + assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + if len(c.Cert.Details.Groups) != 0 { + // We have a new certificate now + r.Log("Certificate between their and relay is updated!") + break + } + + time.Sleep(time.Second) + } + + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) + // We should have two hostinfos on all sides + for len(myControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.Log("yupitdoes") + time.Sleep(time.Second) + } + t.Logf("myControl hostinfos got cleaned up!") + for len(theirControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.Log("yupitdoes") + time.Sleep(time.Second) + } + t.Logf("theirControl hostinfos got cleaned up!") + for len(relayControl.GetHostmap().Indexes) != 2 { + t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) + r.Log("Assert the relay tunnel still works") + assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + r.Log("yupitdoes") + time.Sleep(time.Second) + } + t.Logf("relayControl hostinfos got cleaned up!") +} + +func TestRehandshaking(t *testing.T) { + ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + + // Put their info in our lighthouse and vice versa + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + theirControl.Start() + + t.Log("Stand up a tunnel between me and them") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + + r.RenderHostmaps("Starting hostmaps", myControl, theirControl) + + r.Log("Renew my certificate and spin until their sees it") + _, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) + + caB, err := ca.MarshalToPEM() + if err != nil { + panic(err) + } + + myConfig.Settings["pki"] = m{ + "ca": string(caB), + "cert": string(myNextPEM), + "key": string(myNextPrivKey), + } + rc, err := yaml.Marshal(myConfig.Settings) + assert.NoError(t, err) + myConfig.ReloadConfigString(string(rc)) + + for { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + if len(c.Cert.Details.Groups) != 0 { + // We have a new certificate now + break + } + + time.Sleep(time.Second) + } + + // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly + rc, err = yaml.Marshal(theirConfig.Settings) + assert.NoError(t, err) + var theirNewConfig m + assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) + theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{}) + theirFirewall["inbound"] = []m{{ + "proto": "any", + "port": "any", + "group": "new group", + }} + rc, err = yaml.Marshal(theirNewConfig) + assert.NoError(t, err) + theirConfig.ReloadConfigString(string(rc)) + + r.Log("Spin until there is only 1 tunnel") + for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + t.Log("Connection manager hasn't ticked yet") + time.Sleep(time.Second) + } + + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + myFinalHostmapHosts := myControl.ListHostmapHosts(false) + myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) + theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) + theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) + + // Make sure the correct tunnel won + c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + assert.Contains(t, c.Cert.Details.Groups, "new group") + + // We should only have a single tunnel now on both sides + assert.Len(t, myFinalHostmapHosts, 1) + assert.Len(t, theirFinalHostmapHosts, 1) + assert.Len(t, myFinalHostmapIndexes, 1) + assert.Len(t, theirFinalHostmapIndexes, 1) + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +} + +func TestRehandshakingLoser(t *testing.T) { + // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel + // Should be the one with the new certificate + ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + + // Put their info in our lighthouse and vice versa + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + theirControl.Start() + + t.Log("Stand up a tunnel between me and them") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + + tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + fmt.Println(tt1.LocalIndex, tt2.LocalIndex) + + r.RenderHostmaps("Starting hostmaps", myControl, theirControl) + + r.Log("Renew their certificate and spin until mine sees it") + _, _, theirNextPrivKey, theirNextPEM := newTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) + + caB, err := ca.MarshalToPEM() + if err != nil { + panic(err) + } + + theirConfig.Settings["pki"] = m{ + "ca": string(caB), + "cert": string(theirNextPEM), + "key": string(theirNextPrivKey), + } + rc, err := yaml.Marshal(theirConfig.Settings) + assert.NoError(t, err) + theirConfig.ReloadConfigString(string(rc)) + + for { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + + _, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"] + if theirNewGroup { + break + } + + time.Sleep(time.Second) + } + + // Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly + rc, err = yaml.Marshal(myConfig.Settings) + assert.NoError(t, err) + var myNewConfig m + assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) + theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{}) + theirFirewall["inbound"] = []m{{ + "proto": "any", + "port": "any", + "group": "their new group", + }} + rc, err = yaml.Marshal(myNewConfig) + assert.NoError(t, err) + myConfig.ReloadConfigString(string(rc)) + + r.Log("Spin until there is only 1 tunnel") + for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + t.Log("Connection manager hasn't ticked yet") + time.Sleep(time.Second) + } + + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + myFinalHostmapHosts := myControl.ListHostmapHosts(false) + myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) + theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) + theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) + + // Make sure the correct tunnel won + theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group") + + // We should only have a single tunnel now on both sides + assert.Len(t, myFinalHostmapHosts, 1) + assert.Len(t, theirFinalHostmapHosts, 1) + assert.Len(t, myFinalHostmapIndexes, 1) + assert.Len(t, theirFinalHostmapIndexes, 1) + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + myControl.Stop() + theirControl.Stop() +} + +func TestRaceRegression(t *testing.T) { + // This test forces stage 1, stage 2, stage 1 to be received by me from them + // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which + // caused a cross-linked hostinfo + ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + + // Put their info in our lighthouse + myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + //them rx stage:1 initiatorIndex=642843150 responderIndex=0 + //me rx stage:1 initiatorIndex=120607833 responderIndex=0 + //them rx stage:1 initiatorIndex=642843150 responderIndex=0 + //me rx stage:2 initiatorIndex=642843150 responderIndex=3701775874 + //me rx stage:1 initiatorIndex=120607833 responderIndex=0 + //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 + + t.Log("Start both handshakes") + myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + + t.Log("Get both stage 1") + myStage1ForThem := myControl.GetFromUDP(true) + theirStage1ForMe := theirControl.GetFromUDP(true) + + t.Log("Inject them in a special way") + theirControl.InjectUDPPacket(myStage1ForThem) + myControl.InjectUDPPacket(theirStage1ForMe) + theirControl.InjectUDPPacket(myStage1ForThem) + + //TODO: ensure stage 2 + t.Log("Get both stage 2") + myStage2ForThem := myControl.GetFromUDP(true) + theirStage2ForMe := theirControl.GetFromUDP(true) + + t.Log("Inject them in a special way again") + myControl.InjectUDPPacket(theirStage2ForMe) + myControl.InjectUDPPacket(theirStage1ForMe) + theirControl.InjectUDPPacket(myStage2ForThem) + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + t.Log("Flush the packets") + r.RouteForAllUntilTxTun(myControl) + r.RouteForAllUntilTxTun(theirControl) + r.RenderHostmaps("Starting hostmaps", myControl, theirControl) + + t.Log("Make sure the tunnel still works") + assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +//TODO: test +// Race winner renews and handshakes +// Race loser renews and handshakes +// Does race winner repin the cert to old? //TODO: add a test with many lies diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index ff1347f..e143feb 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -30,7 +30,7 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr) { +func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) { l := NewTestLogger() vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} @@ -78,8 +78,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u "level": l.Level.String(), }, "timers": m{ - "pending_deletion_interval": 4, - "connection_alive_interval": 4, + "pending_deletion_interval": 2, + "connection_alive_interval": 2, }, } @@ -105,7 +105,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u panic(err) } - return control, vpnIpNet, &udpAddr + return control, vpnIpNet, &udpAddr, c } // newTestCaCert will generate a CA cert diff --git a/e2e/router/router.go b/e2e/router/router.go index 98bb31d..730853a 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -215,7 +215,7 @@ func (r *R) renderFlow() { continue } participants[addr] = struct{}{} - sanAddr := strings.Replace(addr, ":", "#58;", 1) + sanAddr := strings.Replace(addr, ":", "-", 1) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n", @@ -252,9 +252,9 @@ func (r *R) renderFlow() { fmt.Fprintf(f, " %s%s%s: %s(%s), index %v, counter: %v\n", - strings.Replace(p.from.GetUDPAddr(), ":", "#58;", 1), + strings.Replace(p.from.GetUDPAddr(), ":", "-", 1), line, - strings.Replace(p.to.GetUDPAddr(), ":", "#58;", 1), + strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } @@ -758,8 +758,8 @@ func (r *R) formatUdpPacket(p *packet) string { data := packet.ApplicationLayer() return fmt.Sprintf( " %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n", - strings.Replace(from, ":", "#58;", 1), - strings.Replace(p.to.GetUDPAddr(), ":", "#58;", 1), + strings.Replace(from, ":", "-", 1), + strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), udp.SrcPort, udp.DstPort, string(data.Payload()), diff --git a/handshake_manager.go b/handshake_manager.go index ce2811b..b02fb28 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -231,7 +231,8 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light WithError(err). Error("Failed to marshal Control message to create relay") } else { - f.SendMessageToVpnIp(header.Control, 0, *relay, msg, make([]byte, 12), make([]byte, mtu)) + // This must send over the hostinfo, not over hm.Hosts[ip] + f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) c.l.WithFields(logrus.Fields{ "relayFrom": c.lightHouse.myVpnIp, "relayTo": vpnIp, @@ -266,7 +267,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light WithError(err). Error("Failed to marshal Control message to create relay") } else { - f.SendMessageToVpnIp(header.Control, 0, *relay, msg, make([]byte, 12), make([]byte, mtu)) + f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) c.l.WithFields(logrus.Fields{ "relayFrom": c.lightHouse.myVpnIp, "relayTo": vpnIp, @@ -328,8 +329,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket testHostInfo := existingHostInfo for testHostInfo != nil { // Is it just a delayed handshake packet? - if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) { - return existingHostInfo, ErrAlreadySeen + if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], testHostInfo.HandshakePacket[handshakePacket]) { + return testHostInfo, ErrAlreadySeen } testHostInfo = testHostInfo.next diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 3be8a1b..5635c40 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -88,4 +88,8 @@ func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte return } +func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { + return +} + func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {} diff --git a/hostmap.go b/hostmap.go index ebfb840..e5949ad 100644 --- a/hostmap.go +++ b/hostmap.go @@ -32,6 +32,7 @@ const RoamingSuppressSeconds = 2 const ( Requested = iota + PeerRequested Established ) @@ -79,6 +80,16 @@ func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) { delete(rs.relays, ip) } +func (rs *RelayState) CopyAllRelayFor() []*Relay { + rs.RLock() + defer rs.RUnlock() + ret := make([]*Relay, 0, len(rs.relayForByIdx)) + for _, r := range rs.relayForByIdx { + ret = append(ret, r) + } + return ret +} + func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) { rs.RLock() defer rs.RUnlock() @@ -279,29 +290,13 @@ func (hm *HostMap) EmitStats(name string) { func (hm *HostMap) RemoveRelay(localIdx uint32) { hm.Lock() - hiRelay, ok := hm.Relays[localIdx] + _, ok := hm.Relays[localIdx] if !ok { hm.Unlock() return } delete(hm.Relays, localIdx) hm.Unlock() - ip, ok := hiRelay.relayState.RemoveRelay(localIdx) - if !ok { - return - } - hiPeer, err := hm.QueryVpnIp(ip) - if err != nil { - return - } - var otherPeerIdx uint32 - hiPeer.relayState.DeleteRelay(hiRelay.vpnIp) - relay, ok := hiPeer.relayState.GetRelayForByIp(hiRelay.vpnIp) - if ok { - otherPeerIdx = relay.LocalIndex - } - // I am a relaying host. I need to remove the other relay, too. - hm.RemoveRelay(otherPeerIdx) } func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) { @@ -395,29 +390,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { hm.unlockedDeleteHostInfo(hostinfo) hm.Unlock() - // And tear down all the relays going through this host, if final - for _, localIdx := range hostinfo.relayState.CopyRelayForIdxs() { - hm.RemoveRelay(localIdx) - } - - if final { - // And tear down the relays this deleted hostInfo was using to be reached - teardownRelayIdx := []uint32{} - for _, relayIp := range hostinfo.relayState.CopyRelayIps() { - relayHostInfo, err := hm.QueryVpnIp(relayIp) - if err != nil { - hm.l.WithError(err).WithField("relay", relayIp).Info("Missing relay host in hostmap") - } else { - if r, ok := relayHostInfo.relayState.QueryRelayForByIp(hostinfo.vpnIp); ok { - teardownRelayIdx = append(teardownRelayIdx, r.LocalIndex) - } - } - } - for _, localIdx := range teardownRelayIdx { - hm.RemoveRelay(localIdx) - } - } - return final } @@ -508,6 +480,10 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Hostmap hostInfo deleted") } + + for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { + delete(hm.Relays, localRelayIdx) + } } func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) { @@ -562,6 +538,24 @@ func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) { return hm.queryVpnIp(vpnIp, nil) } +func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) { + hm.RLock() + defer hm.RUnlock() + + h, ok := hm.Hosts[relayHostIp] + if !ok { + return nil, nil, errors.New("unable to find host") + } + for h != nil { + r, ok := h.relayState.QueryRelayForByIp(targetIp) + if ok && r.State == Established { + return h, r, nil + } + h = h.next + } + return nil, nil, errors.New("unable to find host with relay") +} + // PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every // `PromoteEvery` calls to this function for a given host. func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) { @@ -709,7 +703,6 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) { i.packetStore = make([]*cachedPacket, 0) i.ConnectionState.ready = true i.ConnectionState.queueLock.Unlock() - i.ConnectionState.certState = nil } func (i *HostInfo) GetCert() *cert.NebulaCertificate { diff --git a/inside.go b/inside.go index 9c40251..18148b6 100644 --- a/inside.go +++ b/inside.go @@ -57,7 +57,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } ci := hostinfo.ConnectionState - if ci.ready == false { + if !ci.ready { // Because we might be sending stored packets, lock here to stop new things going to // the packet queue. ci.queueLock.Lock() @@ -177,7 +177,7 @@ func (f *Interface) initHostInfo(hostinfo *HostInfo) { hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0) } -func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) { +func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { fp := &firewall.Packet{} err := newPacket(p, false, fp) if err != nil { @@ -186,7 +186,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(p, *fp, false, hostInfo, f.caPool, nil) + dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.caPool, nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). @@ -196,7 +196,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp return } - f.sendNoMetrics(header.Message, st, hostInfo.ConnectionState, hostInfo, nil, p, nb, out, 0) + f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0) } // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp @@ -215,19 +215,18 @@ func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSu // the packet queue. hostInfo.ConnectionState.queueLock.Lock() if !hostInfo.ConnectionState.ready { - hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToVpnIp, f.cachedPacketMetrics) + hostInfo.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) hostInfo.ConnectionState.queueLock.Unlock() return } hostInfo.ConnectionState.queueLock.Unlock() } - f.sendMessageToVpnIp(t, st, hostInfo, p, nb, out) - return + f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out) } -func (f *Interface) sendMessageToVpnIp(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) { - f.send(t, st, hostInfo.ConnectionState, hostInfo, p, nb, out) +func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hi *HostInfo, p, nb, out []byte) { + f.send(t, st, hi.ConnectionState, hi, p, nb, out) } func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { @@ -302,6 +301,7 @@ func (f *Interface) SendVia(via *HostInfo, if err != nil { via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia") } + f.connectionManager.RelayUsed(relay.LocalIndex) } func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) { @@ -372,31 +372,19 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } else { // Try to send via a relay for _, relayIP := range hostinfo.relayState.CopyRelayIps() { - relayHostInfo, err := f.hostMap.QueryVpnIp(relayIP) + relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP) if err != nil { + hostinfo.relayState.DeleteRelay(relayIP) hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") continue } - relay, ok := relayHostInfo.relayState.QueryRelayForByIp(hostinfo.vpnIp) - if !ok { - hostinfo.logger(f.l). - WithField("relay", relayHostInfo.vpnIp). - WithField("relayTo", hostinfo.vpnIp). - Info("sendNoMetrics relay missing object for target") - continue - } f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) break } } - return } func isMulticast(ip iputil.VpnIp) bool { // Class D multicast - if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 { - return true - } - - return false + return (((ip >> 24) & 0xff) & 0xf0) == 0xe0 } diff --git a/interface.go b/interface.go index 09c75ee..220cb25 100644 --- a/interface.go +++ b/interface.go @@ -99,6 +99,7 @@ type EncWriter interface { nocopy bool, ) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) + SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) Handshake(vpnIp iputil.VpnIp) } diff --git a/lighthouse_test.go b/lighthouse_test.go index 1824463..658c087 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -377,6 +377,23 @@ func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { } +func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) { + msg := &NebulaMeta{} + err := msg.Unmarshal(p) + if tw.metaFilter == nil || msg.Type == *tw.metaFilter { + tw.lastReply = testLhReply{ + nebType: t, + nebSubType: st, + vpnIp: hostinfo.vpnIp, + msg: msg, + } + } + + if err != nil { + panic(err) + } +} + func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) { msg := &NebulaMeta{} err := msg.Unmarshal(p) diff --git a/outside.go b/outside.go index 8361ce3..19f5931 100644 --- a/outside.go +++ b/outside.go @@ -83,7 +83,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt switch h.Subtype { case header.MessageNone: - f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) + if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { + return + } case header.MessageRelay: // The entire body is sent as AD, not encrypted. // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. @@ -100,7 +102,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt signedPayload = signedPayload[header.Len:] // Pull the Roaming parts up here, and return in all call paths. f.handleHostRoaming(hostinfo, addr) + // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo.localIndexId) + f.connectionManager.RelayUsed(h.RemoteIndex) relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) if !ok { @@ -118,17 +122,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt return case ForwardingType: // Find the target HostInfo relay object - targetHI, err := f.hostMap.QueryVpnIp(relay.PeerIp) + targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp) if err != nil { hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip") return } - // find the target Relay info object - targetRelay, ok := targetHI.relayState.QueryRelayForByIp(hostinfo.vpnIp) - if !ok { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp}).Info("Failed to find relay in hostinfo") - return - } // If that relay is Established, forward the payload through it if targetRelay.State == Established { @@ -382,7 +380,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { var err error out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) @@ -390,20 +388,20 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") //TODO: maybe after build 64 is out? 06/14/2018 - NB //f.sendRecvError(hostinfo.remote, header.RemoteIndex) - return + return false } err = newPacket(out, true, fwPacket) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("packet", out). Warnf("Error while validating inbound packet") - return + return false } if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). Debugln("dropping out of window packet") - return + return false } dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache) @@ -414,7 +412,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out WithField("reason", dropReason). Debugln("dropping inbound packet") } - return + return false } f.connectionManager.In(hostinfo.localIndexId) @@ -422,6 +420,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out if err != nil { f.l.WithError(err).Error("Failed to write to tun") } + return true } func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) { diff --git a/punchy.go b/punchy.go index c0bdbd3..2034405 100644 --- a/punchy.go +++ b/punchy.go @@ -75,7 +75,7 @@ func (p *Punchy) reload(c *config.C, initial bool) { } if initial || c.HasChanged("punchy.target_all_remotes") { - p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", true)) + p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) if !initial { p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed") } diff --git a/relay_manager.go b/relay_manager.go index bf75708..fb90eec 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -141,27 +141,29 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") return } - peerRelay.State = Established - resp := NebulaControl{ - Type: NebulaControl_CreateRelayResponse, - ResponderRelayIndex: peerRelay.LocalIndex, - InitiatorRelayIndex: peerRelay.RemoteIndex, - RelayFromIp: uint32(peerHostInfo.vpnIp), - RelayToIp: uint32(target), - } - msg, err := resp.Marshal() - if err != nil { - rm.l. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") - } else { - f.SendMessageToVpnIp(header.Control, 0, peerHostInfo.vpnIp, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": peerHostInfo.vpnIp}). - Info("send CreateRelayResponse") + if peerRelay.State == PeerRequested { + peerRelay.State = Established + resp := NebulaControl{ + Type: NebulaControl_CreateRelayResponse, + ResponderRelayIndex: peerRelay.LocalIndex, + InitiatorRelayIndex: peerRelay.RemoteIndex, + RelayFromIp: uint32(peerHostInfo.vpnIp), + RelayToIp: uint32(target), + } + msg, err := resp.Marshal() + if err != nil { + rm.l. + WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + } else { + f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.WithFields(logrus.Fields{ + "relayFrom": iputil.VpnIp(resp.RelayFromIp), + "relayTo": iputil.VpnIp(resp.RelayToIp), + "initiatorRelayIndex": resp.InitiatorRelayIndex, + "responderRelayIndex": resp.ResponderRelayIndex, + "vpnIp": peerHostInfo.vpnIp}). + Info("send CreateRelayResponse") + } } } @@ -223,7 +225,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N logMsg. WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") } else { - f.SendMessageToVpnIp(header.Control, 0, h.vpnIp, msg, make([]byte, 12), make([]byte, mtu)) + f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ "relayFrom": iputil.VpnIp(resp.RelayFromIp), "relayTo": iputil.VpnIp(resp.RelayToIp), @@ -278,7 +280,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N logMsg. WithError(err).Error("relayManager Failed to marshal Control message to create relay") } else { - f.SendMessageToVpnIp(header.Control, 0, target, msg, make([]byte, 12), make([]byte, mtu)) + f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ "relayFrom": iputil.VpnIp(req.RelayFromIp), "relayTo": iputil.VpnIp(req.RelayToIp), @@ -292,7 +294,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N relay, ok := h.relayState.QueryRelayForByIp(target) if !ok { // Add the relay - state := Requested + state := PeerRequested if targetRelay != nil && targetRelay.State == Established { state = Established } @@ -324,7 +326,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N rm.l. WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") } else { - f.SendMessageToVpnIp(header.Control, 0, h.vpnIp, msg, make([]byte, 12), make([]byte, mtu)) + f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ "relayFrom": iputil.VpnIp(resp.RelayFromIp), "relayTo": iputil.VpnIp(resp.RelayToIp),