mirror of https://github.com/slackhq/nebula.git
Use atomic.Pointer for certState (#833)
This commit is contained in:
parent
2801fb2286
commit
6b3d42efa5
|
@ -54,12 +54,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
hostMap: hostMap,
|
||||
inside: &test.NoopTun{},
|
||||
outside: &udp.Conn{},
|
||||
certState: cs,
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
||||
l: l,
|
||||
}
|
||||
ifce.certState.Store(cs)
|
||||
now := time.Now()
|
||||
|
||||
// Create manager
|
||||
|
@ -130,12 +130,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||
hostMap: hostMap,
|
||||
inside: &test.NoopTun{},
|
||||
outside: &udp.Conn{},
|
||||
certState: cs,
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
||||
l: l,
|
||||
}
|
||||
ifce.certState.Store(cs)
|
||||
now := time.Now()
|
||||
|
||||
// Create manager
|
||||
|
@ -245,7 +245,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||
hostMap: hostMap,
|
||||
inside: &test.NoopTun{},
|
||||
outside: &udp.Conn{},
|
||||
certState: cs,
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
|
||||
|
@ -253,6 +252,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||
disconnectInvalid: true,
|
||||
caPool: ncp,
|
||||
}
|
||||
ifce.certState.Store(cs)
|
||||
|
||||
// Create manager
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
|
|
@ -33,7 +33,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
|
|||
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
}
|
||||
|
||||
curCertState := f.certState
|
||||
curCertState := f.certState.Load()
|
||||
static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
|
||||
|
||||
b := NewBits(ReplayWindow)
|
||||
|
|
|
@ -161,5 +161,5 @@ func (c *Control) GetHostmap() *HostMap {
|
|||
}
|
||||
|
||||
func (c *Control) GetCert() *cert.NebulaCertificate {
|
||||
return c.f.certState.certificate
|
||||
return c.f.certState.Load().certificate
|
||||
}
|
||||
|
|
11
interface.go
11
interface.go
|
@ -52,7 +52,7 @@ type Interface struct {
|
|||
hostMap *HostMap
|
||||
outside *udp.Conn
|
||||
inside overlay.Device
|
||||
certState *CertState
|
||||
certState atomic.Pointer[CertState]
|
||||
cipher string
|
||||
firewall *Firewall
|
||||
connectionManager *connectionManager
|
||||
|
@ -141,7 +141,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||
hostMap: c.HostMap,
|
||||
outside: c.Outside,
|
||||
inside: c.Inside,
|
||||
certState: c.certState,
|
||||
cipher: c.Cipher,
|
||||
firewall: c.Firewall,
|
||||
serveDns: c.ServeDns,
|
||||
|
@ -172,6 +171,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||
l: c.l,
|
||||
}
|
||||
|
||||
ifce.certState.Store(c.certState)
|
||||
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
|
||||
|
||||
return ifce, nil
|
||||
|
@ -298,14 +298,15 @@ func (f *Interface) reloadCertKey(c *config.C) {
|
|||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
oldIPs := f.certState.certificate.Details.Ips
|
||||
currentCert := f.certState.Load().certificate
|
||||
oldIPs := currentCert.Details.Ips
|
||||
newIPs := cs.certificate.Details.Ips
|
||||
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
|
||||
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
|
||||
return
|
||||
}
|
||||
|
||||
f.certState = cs
|
||||
f.certState.Store(cs)
|
||||
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
|
||||
}
|
||||
|
||||
|
@ -316,7 +317,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||
return
|
||||
}
|
||||
|
||||
fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c)
|
||||
fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue