Use atomic.Pointer for certState (#833)

This commit is contained in:
Nate Brown 2023-03-30 13:04:09 -05:00 committed by GitHub
parent 2801fb2286
commit 6b3d42efa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 12 additions and 11 deletions

View File

@ -54,12 +54,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
l: l,
}
ifce.certState.Store(cs)
now := time.Now()
// Create manager
@ -130,12 +130,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
l: l,
}
ifce.certState.Store(cs)
now := time.Now()
// Create manager
@ -245,7 +245,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
hostMap: hostMap,
inside: &test.NoopTun{},
outside: &udp.Conn{},
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
@ -253,6 +252,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
disconnectInvalid: true,
caPool: ncp,
}
ifce.certState.Store(cs)
// Create manager
ctx, cancel := context.WithCancel(context.Background())

View File

@ -33,7 +33,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
}
curCertState := f.certState
curCertState := f.certState.Load()
static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
b := NewBits(ReplayWindow)

View File

@ -161,5 +161,5 @@ func (c *Control) GetHostmap() *HostMap {
}
func (c *Control) GetCert() *cert.NebulaCertificate {
return c.f.certState.certificate
return c.f.certState.Load().certificate
}

View File

@ -52,7 +52,7 @@ type Interface struct {
hostMap *HostMap
outside *udp.Conn
inside overlay.Device
certState *CertState
certState atomic.Pointer[CertState]
cipher string
firewall *Firewall
connectionManager *connectionManager
@ -141,7 +141,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
hostMap: c.HostMap,
outside: c.Outside,
inside: c.Inside,
certState: c.certState,
cipher: c.Cipher,
firewall: c.Firewall,
serveDns: c.ServeDns,
@ -172,6 +171,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l,
}
ifce.certState.Store(c.certState)
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
return ifce, nil
@ -298,14 +298,15 @@ func (f *Interface) reloadCertKey(c *config.C) {
}
// did IP in cert change? if so, don't set
oldIPs := f.certState.certificate.Details.Ips
currentCert := f.certState.Load().certificate
oldIPs := currentCert.Details.Ips
newIPs := cs.certificate.Details.Ips
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
return
}
f.certState = cs
f.certState.Store(cs)
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
}
@ -316,7 +317,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
return
}
fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c)
fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
return

2
ssh.go
View File

@ -753,7 +753,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return nil
}
cert := ifce.certState.certificate
cert := ifce.certState.Load().certificate
if len(a) > 0 {
parsedIp := net.ParseIP(a[0])
if parsedIp == nil {