diff --git a/connection_manager_test.go b/connection_manager_test.go index 51e331b..b02c1bf 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -54,12 +54,12 @@ func Test_NewConnectionManagerTest(t *testing.T) { hostMap: hostMap, inside: &test.NoopTun{}, outside: &udp.Conn{}, - certState: cs, firewall: &Firewall{}, lightHouse: lh, handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), l: l, } + ifce.certState.Store(cs) now := time.Now() // Create manager @@ -130,12 +130,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) { hostMap: hostMap, inside: &test.NoopTun{}, outside: &udp.Conn{}, - certState: cs, firewall: &Firewall{}, lightHouse: lh, handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), l: l, } + ifce.certState.Store(cs) now := time.Now() // Create manager @@ -245,7 +245,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { hostMap: hostMap, inside: &test.NoopTun{}, outside: &udp.Conn{}, - certState: cs, firewall: &Firewall{}, lightHouse: lh, handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig), @@ -253,6 +252,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { disconnectInvalid: true, caPool: ncp, } + ifce.certState.Store(cs) // Create manager ctx, cancel := context.WithCancel(context.Background()) diff --git a/connection_state.go b/connection_state.go index 6bbb02f..2a7be15 100644 --- a/connection_state.go +++ b/connection_state.go @@ -33,7 +33,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) } - curCertState := f.certState + curCertState := f.certState.Load() static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey} b := NewBits(ReplayWindow) diff --git a/control_tester.go b/control_tester.go index 4fa0763..550c986 100644 --- a/control_tester.go +++ b/control_tester.go @@ -161,5 +161,5 @@ func (c *Control) GetHostmap() *HostMap { } func (c *Control) GetCert() *cert.NebulaCertificate { - return c.f.certState.certificate + return c.f.certState.Load().certificate } diff --git a/interface.go b/interface.go index 632e823..cc6e781 100644 --- a/interface.go +++ b/interface.go @@ -52,7 +52,7 @@ type Interface struct { hostMap *HostMap outside *udp.Conn inside overlay.Device - certState *CertState + certState atomic.Pointer[CertState] cipher string firewall *Firewall connectionManager *connectionManager @@ -141,7 +141,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { hostMap: c.HostMap, outside: c.Outside, inside: c.Inside, - certState: c.certState, cipher: c.Cipher, firewall: c.Firewall, serveDns: c.ServeDns, @@ -172,6 +171,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } + ifce.certState.Store(c.certState) ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval) return ifce, nil @@ -298,14 +298,15 @@ func (f *Interface) reloadCertKey(c *config.C) { } // did IP in cert change? if so, don't set - oldIPs := f.certState.certificate.Details.Ips + currentCert := f.certState.Load().certificate + oldIPs := currentCert.Details.Ips newIPs := cs.certificate.Details.Ips if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old") return } - f.certState = cs + f.certState.Store(cs) f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk") } @@ -316,7 +317,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return diff --git a/ssh.go b/ssh.go index 438fbeb..6223314 100644 --- a/ssh.go +++ b/ssh.go @@ -753,7 +753,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return nil } - cert := ifce.certState.certificate + cert := ifce.certState.Load().certificate if len(a) > 0 { parsedIp := net.ParseIP(a[0]) if parsedIp == nil {