We only need the certificate in ConnectionState (#953)

This commit is contained in:
Nate Brown 2023-08-21 14:11:06 -05:00 committed by GitHub
parent 5a131b2975
commit 7edcf620c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 37 additions and 51 deletions

View File

@ -406,7 +406,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
}
certState := n.intf.pki.GetCertState()
return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature)
return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
}
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
@ -465,7 +465,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.pki.GetCertState()
if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) {
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
return
}
@ -474,7 +474,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
Info("Re-handshaking with remote")
//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp)
if !newHostinfo.HandshakeReady {
ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
}

View File

@ -79,8 +79,8 @@ func Test_NewConnectionManagerTest(t *testing.T) {
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
myCert: &cert.NebulaCertificate{},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -159,8 +159,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
myCert: &cert.NebulaCertificate{},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -222,7 +222,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
PublicKey: pubCA,
},
}
caCert.Sign(cert.Curve_CURVE25519, privCA)
assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
ncp := &cert.NebulaCAPool{
CAs: cert.NewCAPool().CAs,
}
@ -241,7 +242,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
Issuer: "ca",
},
}
peerCert.Sign(cert.Curve_CURVE25519, privCA)
assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
cs := &CertState{
RawCertificate: []byte{},
@ -275,9 +276,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
hostinfo := &HostInfo{
vpnIp: vpnIp,
ConnectionState: &ConnectionState{
certState: cs,
peerCert: &peerCert,
H: &noise.HandshakeState{},
myCert: &cert.NebulaCertificate{},
peerCert: &peerCert,
H: &noise.HandshakeState{},
},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

View File

@ -18,7 +18,7 @@ type ConnectionState struct {
eKey *NebulaCipherState
dKey *NebulaCipherState
H *noise.HandshakeState
certState *CertState
myCert *cert.NebulaCertificate
peerCert *cert.NebulaCertificate
initiator bool
messageCounter atomic.Uint64
@ -28,25 +28,27 @@ type ConnectionState struct {
ready bool
}
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
var dhFunc noise.DHFunc
curCertState := f.pki.GetCertState()
switch curCertState.Certificate.Details.Curve {
switch certState.Certificate.Details.Curve {
case cert.Curve_CURVE25519:
dhFunc = noise.DH25519
case cert.Curve_P256:
dhFunc = noiseutil.DHP256
default:
l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve)
l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
return nil
}
cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
if f.cipher == "chachapoly" {
var cs noise.CipherSuite
if cipher == "chachapoly" {
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
} else {
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
}
static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey}
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
@ -72,7 +74,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
initiator: initiator,
window: b,
ready: false,
certState: curCertState,
myCert: certState.Certificate,
}
return ci

View File

@ -165,7 +165,7 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
}
func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo)
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp)
ixHandshakeStage0(c.f, vpnIp, hostinfo)
// If this is a static host, we don't need to wait for the HostQueryReply

View File

@ -28,12 +28,14 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
return
}
ci := hostinfo.ConnectionState
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
hostinfo.ConnectionState = ci
hsProto := &NebulaHandshakeDetails{
InitiatorIndex: hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: ci.certState.RawCertificateNoKey,
Cert: certState.RawCertificateNoKey,
}
hsBytes := []byte{}
@ -69,7 +71,8 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
}
func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
@ -155,7 +158,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
Info("Handshake message received")
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = ci.certState.RawCertificateNoKey
hs.Details.Cert = certState.RawCertificateNoKey
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano())

View File

@ -297,7 +297,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
}
// AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it.
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
// A write lock is used to avoid having to recheck the map and trading a read lock for a write lock
c.Lock()
defer c.Unlock()
@ -317,10 +317,6 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H
},
}
if init != nil {
init(hostinfo)
}
c.vpnIps[vpnIp] = hostinfo
c.metricInitiated.Inc(1)
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)

View File

@ -28,17 +28,8 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
var initCalled bool
initFunc := func(*HostInfo) {
initCalled = true
}
i := blah.AddVpnIp(ip, initFunc)
assert.True(t, initCalled)
initCalled = false
i2 := blah.AddVpnIp(ip, initFunc)
assert.False(t, initCalled)
i := blah.AddVpnIp(ip)
i2 := blah.AddVpnIp(ip)
assert.Same(t, i, i2)
i.remotes = NewRemoteList(nil)

View File

@ -1,7 +1,6 @@
package nebula
import (
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@ -124,7 +123,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
if hostinfo == nil {
hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
}
ci := hostinfo.ConnectionState
@ -168,12 +167,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
return hostinfo
}
// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
// will create the initial Noise ConnectionState
func (f *Interface) initHostInfo(hostinfo *HostInfo) {
hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
}
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
fp := &firewall.Packet{}
err := newPacket(p, false, fp)

2
ssh.go
View File

@ -607,7 +607,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
}
}
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo)
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
if addr != nil {
hostInfo.SetRemote(addr)
}