mirror of https://github.com/slackhq/nebula.git
We only need the certificate in ConnectionState (#953)
This commit is contained in:
parent
5a131b2975
commit
7edcf620c0
|
@ -406,7 +406,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
|||
}
|
||||
|
||||
certState := n.intf.pki.GetCertState()
|
||||
return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature)
|
||||
return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
|
||||
}
|
||||
|
||||
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||
|
@ -465,7 +465,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
|
|||
|
||||
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
certState := n.intf.pki.GetCertState()
|
||||
if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) {
|
||||
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -474,7 +474,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
|||
Info("Re-handshaking with remote")
|
||||
|
||||
//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
|
||||
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
|
||||
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp)
|
||||
if !newHostinfo.HandshakeReady {
|
||||
ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
|
||||
}
|
||||
|
|
|
@ -79,8 +79,8 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
remoteIndexId: 9901,
|
||||
}
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
certState: cs,
|
||||
H: &noise.HandshakeState{},
|
||||
myCert: &cert.NebulaCertificate{},
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
|
||||
|
@ -159,8 +159,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||
remoteIndexId: 9901,
|
||||
}
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
certState: cs,
|
||||
H: &noise.HandshakeState{},
|
||||
myCert: &cert.NebulaCertificate{},
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
|
||||
|
@ -222,7 +222,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||
PublicKey: pubCA,
|
||||
},
|
||||
}
|
||||
caCert.Sign(cert.Curve_CURVE25519, privCA)
|
||||
|
||||
assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
|
||||
ncp := &cert.NebulaCAPool{
|
||||
CAs: cert.NewCAPool().CAs,
|
||||
}
|
||||
|
@ -241,7 +242,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||
Issuer: "ca",
|
||||
},
|
||||
}
|
||||
peerCert.Sign(cert.Curve_CURVE25519, privCA)
|
||||
assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
|
||||
|
||||
cs := &CertState{
|
||||
RawCertificate: []byte{},
|
||||
|
@ -275,9 +276,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||
hostinfo := &HostInfo{
|
||||
vpnIp: vpnIp,
|
||||
ConnectionState: &ConnectionState{
|
||||
certState: cs,
|
||||
peerCert: &peerCert,
|
||||
H: &noise.HandshakeState{},
|
||||
myCert: &cert.NebulaCertificate{},
|
||||
peerCert: &peerCert,
|
||||
H: &noise.HandshakeState{},
|
||||
},
|
||||
}
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
|
|
|
@ -18,7 +18,7 @@ type ConnectionState struct {
|
|||
eKey *NebulaCipherState
|
||||
dKey *NebulaCipherState
|
||||
H *noise.HandshakeState
|
||||
certState *CertState
|
||||
myCert *cert.NebulaCertificate
|
||||
peerCert *cert.NebulaCertificate
|
||||
initiator bool
|
||||
messageCounter atomic.Uint64
|
||||
|
@ -28,25 +28,27 @@ type ConnectionState struct {
|
|||
ready bool
|
||||
}
|
||||
|
||||
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||
var dhFunc noise.DHFunc
|
||||
curCertState := f.pki.GetCertState()
|
||||
|
||||
switch curCertState.Certificate.Details.Curve {
|
||||
switch certState.Certificate.Details.Curve {
|
||||
case cert.Curve_CURVE25519:
|
||||
dhFunc = noise.DH25519
|
||||
case cert.Curve_P256:
|
||||
dhFunc = noiseutil.DHP256
|
||||
default:
|
||||
l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve)
|
||||
l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
|
||||
return nil
|
||||
}
|
||||
cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
|
||||
if f.cipher == "chachapoly" {
|
||||
|
||||
var cs noise.CipherSuite
|
||||
if cipher == "chachapoly" {
|
||||
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
} else {
|
||||
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
|
||||
}
|
||||
|
||||
static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey}
|
||||
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
|
||||
|
||||
b := NewBits(ReplayWindow)
|
||||
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
|
||||
|
@ -72,7 +74,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
|
|||
initiator: initiator,
|
||||
window: b,
|
||||
ready: false,
|
||||
certState: curCertState,
|
||||
myCert: certState.Certificate,
|
||||
}
|
||||
|
||||
return ci
|
||||
|
|
|
@ -165,7 +165,7 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
|
|||
}
|
||||
|
||||
func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
|
||||
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo)
|
||||
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp)
|
||||
ixHandshakeStage0(c.f, vpnIp, hostinfo)
|
||||
|
||||
// If this is a static host, we don't need to wait for the HostQueryReply
|
||||
|
|
|
@ -28,12 +28,14 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
|
|||
return
|
||||
}
|
||||
|
||||
ci := hostinfo.ConnectionState
|
||||
certState := f.pki.GetCertState()
|
||||
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
|
||||
hostinfo.ConnectionState = ci
|
||||
|
||||
hsProto := &NebulaHandshakeDetails{
|
||||
InitiatorIndex: hostinfo.localIndexId,
|
||||
Time: uint64(time.Now().UnixNano()),
|
||||
Cert: ci.certState.RawCertificateNoKey,
|
||||
Cert: certState.RawCertificateNoKey,
|
||||
}
|
||||
|
||||
hsBytes := []byte{}
|
||||
|
@ -69,7 +71,8 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
|
|||
}
|
||||
|
||||
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
|
||||
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
|
||||
certState := f.pki.GetCertState()
|
||||
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 1)
|
||||
|
||||
|
@ -155,7 +158,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
|
|||
Info("Handshake message received")
|
||||
|
||||
hs.Details.ResponderIndex = myIndex
|
||||
hs.Details.Cert = ci.certState.RawCertificateNoKey
|
||||
hs.Details.Cert = certState.RawCertificateNoKey
|
||||
// Update the time in case their clock is way off from ours
|
||||
hs.Details.Time = uint64(time.Now().UnixNano())
|
||||
|
||||
|
|
|
@ -297,7 +297,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
|
|||
}
|
||||
|
||||
// AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it.
|
||||
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
|
||||
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
|
||||
// A write lock is used to avoid having to recheck the map and trading a read lock for a write lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
@ -317,10 +317,6 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H
|
|||
},
|
||||
}
|
||||
|
||||
if init != nil {
|
||||
init(hostinfo)
|
||||
}
|
||||
|
||||
c.vpnIps[vpnIp] = hostinfo
|
||||
c.metricInitiated.Inc(1)
|
||||
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
|
||||
|
|
|
@ -28,17 +28,8 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
|||
now := time.Now()
|
||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||
|
||||
var initCalled bool
|
||||
initFunc := func(*HostInfo) {
|
||||
initCalled = true
|
||||
}
|
||||
|
||||
i := blah.AddVpnIp(ip, initFunc)
|
||||
assert.True(t, initCalled)
|
||||
|
||||
initCalled = false
|
||||
i2 := blah.AddVpnIp(ip, initFunc)
|
||||
assert.False(t, initCalled)
|
||||
i := blah.AddVpnIp(ip)
|
||||
i2 := blah.AddVpnIp(ip)
|
||||
assert.Same(t, i, i2)
|
||||
|
||||
i.remotes = NewRemoteList(nil)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"github.com/flynn/noise"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
|
@ -124,7 +123,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
|||
|
||||
hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
|
||||
if hostinfo == nil {
|
||||
hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
|
||||
hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
|
||||
}
|
||||
ci := hostinfo.ConnectionState
|
||||
|
||||
|
@ -168,12 +167,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
|||
return hostinfo
|
||||
}
|
||||
|
||||
// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
|
||||
// will create the initial Noise ConnectionState
|
||||
func (f *Interface) initHostInfo(hostinfo *HostInfo) {
|
||||
hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
|
||||
}
|
||||
|
||||
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
||||
fp := &firewall.Packet{}
|
||||
err := newPacket(p, false, fp)
|
||||
|
|
Loading…
Reference in New Issue