diff --git a/cert.go b/cert.go deleted file mode 100644 index bbd29c6..0000000 --- a/cert.go +++ /dev/null @@ -1,163 +0,0 @@ -package nebula - -import ( - "errors" - "fmt" - "io/ioutil" - "strings" - "time" - - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/config" -) - -type CertState struct { - certificate *cert.NebulaCertificate - rawCertificate []byte - rawCertificateNoKey []byte - publicKey []byte - privateKey []byte -} - -func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { - // Marshal the certificate to ensure it is valid - rawCertificate, err := certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) - } - - publicKey := certificate.Details.PublicKey - cs := &CertState{ - rawCertificate: rawCertificate, - certificate: certificate, // PublicKey has been set to nil above - privateKey: privateKey, - publicKey: publicKey, - } - - cs.certificate.Details.PublicKey = nil - rawCertNoKey, err := cs.certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("error marshalling certificate no key: %s", err) - } - cs.rawCertificateNoKey = rawCertNoKey - // put public key back - cs.certificate.Details.PublicKey = cs.publicKey - return cs, nil -} - -func NewCertStateFromConfig(c *config.C) (*CertState, error) { - var pemPrivateKey []byte - var err error - - privPathOrPEM := c.GetString("pki.key", "") - - if privPathOrPEM == "" { - return nil, errors.New("no pki.key path or PEM data provided") - } - - if strings.Contains(privPathOrPEM, "-----BEGIN") { - pemPrivateKey = []byte(privPathOrPEM) - privPathOrPEM = "" - } else { - pemPrivateKey, err = ioutil.ReadFile(privPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) - } - } - - rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey) - if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) - } - - var rawCert []byte - - pubPathOrPEM := c.GetString("pki.cert", "") - - if pubPathOrPEM == "" { - return nil, errors.New("no pki.cert path or PEM data provided") - } - - if strings.Contains(pubPathOrPEM, "-----BEGIN") { - rawCert = []byte(pubPathOrPEM) - pubPathOrPEM = "" - } else { - rawCert, err = ioutil.ReadFile(pubPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err) - } - } - - nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) - if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) - } - - if nebulaCert.Expired(time.Now()) { - return nil, fmt.Errorf("nebula certificate for this host is expired") - } - - if len(nebulaCert.Details.Ips) == 0 { - return nil, fmt.Errorf("no IPs encoded in certificate") - } - - if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { - return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") - } - - return NewCertState(nebulaCert, rawKey) -} - -func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { - var rawCA []byte - var err error - - caPathOrPEM := c.GetString("pki.ca", "") - if caPathOrPEM == "" { - return nil, errors.New("no pki.ca path or PEM data provided") - } - - if strings.Contains(caPathOrPEM, "-----BEGIN") { - rawCA = []byte(caPathOrPEM) - - } else { - rawCA, err = ioutil.ReadFile(caPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) - } - } - - CAs, err := cert.NewCAPoolFromBytes(rawCA) - if errors.Is(err, cert.ErrExpired) { - var expired int - for _, cert := range CAs.CAs { - if cert.Expired(time.Now()) { - expired++ - l.WithField("cert", cert).Warn("expired certificate present in CA pool") - } - } - - if expired >= len(CAs.CAs) { - return nil, errors.New("no valid CA certificates present") - } - - } else if err != nil { - return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) - } - - for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) { - l.WithField("fingerprint", fp).Info("Blocklisting cert") - CAs.BlocklistFingerprint(fp) - } - - // Support deprecated config for at least one minor release to allow for migrations - //TODO: remove in 2022 or later - for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) { - l.WithField("fingerprint", fp).Info("Blocklisting cert") - l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist") - CAs.BlocklistFingerprint(fp) - } - - return CAs, nil -} diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index c1de267..8d0eaa1 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -59,13 +59,8 @@ func main() { } ctrl, err := nebula.Main(c, *configTest, Build, l, nil) - - switch v := err.(type) { - case util.ContextualError: - v.Log(l) - os.Exit(1) - case error: - l.WithError(err).Error("Failed to start") + if err != nil { + util.LogWithContextIfNeeded("Failed to start", err, l) os.Exit(1) } diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index 9461035..5cf0a02 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -53,13 +53,8 @@ func main() { } ctrl, err := nebula.Main(c, *configTest, Build, l, nil) - - switch v := err.(type) { - case util.ContextualError: - v.Log(l) - os.Exit(1) - case error: - l.WithError(err).Error("Failed to start") + if err != nil { + util.LogWithContextIfNeeded("Failed to start", err, l) os.Exit(1) } diff --git a/connection_manager.go b/connection_manager.go index 528cf1b..62a8dd2 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -405,8 +405,8 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { return false } - certState := n.intf.certState.Load() - return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) + certState := n.intf.pki.GetCertState() + return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) } func (n *connectionManager) swapPrimary(current, primary *HostInfo) { @@ -427,7 +427,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } - valid, err := remoteCert.VerifyWithCache(now, n.intf.caPool) + valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool()) if valid { return false } @@ -464,8 +464,8 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { - certState := n.intf.certState.Load() - if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) { + certState := n.intf.pki.GetCertState() + if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) { return } diff --git a/connection_manager_test.go b/connection_manager_test.go index 642e055..a489bf2 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -44,10 +44,10 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Very incomplete mock objects hostMap := NewHostMap(l, vpncidr, preferredRanges) cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() @@ -57,10 +57,11 @@ func Test_NewConnectionManagerTest(t *testing.T) { outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, + pki: &PKI{}, handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } - ifce.certState.Store(cs) + ifce.pki.cs.Store(cs) // Create manager ctx, cancel := context.WithCancel(context.Background()) @@ -123,10 +124,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Very incomplete mock objects hostMap := NewHostMap(l, vpncidr, preferredRanges) cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() @@ -136,10 +137,11 @@ func Test_NewConnectionManagerTest2(t *testing.T) { outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, + pki: &PKI{}, handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } - ifce.certState.Store(cs) + ifce.pki.cs.Store(cs) // Create manager ctx, cancel := context.WithCancel(context.Background()) @@ -242,10 +244,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { peerCert.Sign(cert.Curve_CURVE25519, privCA) cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() @@ -258,9 +260,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, disconnectInvalid: true, - caPool: ncp, + pki: &PKI{}, } - ifce.certState.Store(cs) + ifce.pki.cs.Store(cs) + ifce.pki.caPool.Store(ncp) // Create manager ctx, cancel := context.WithCancel(context.Background()) diff --git a/connection_state.go b/connection_state.go index ab818c9..163e4bc 100644 --- a/connection_state.go +++ b/connection_state.go @@ -30,15 +30,15 @@ type ConnectionState struct { func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { var dhFunc noise.DHFunc - curCertState := f.certState.Load() + curCertState := f.pki.GetCertState() - switch curCertState.certificate.Details.Curve { + switch curCertState.Certificate.Details.Curve { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: dhFunc = noiseutil.DHP256 default: - l.Errorf("invalid curve: %s", curCertState.certificate.Details.Curve) + l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve) return nil } cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) @@ -46,7 +46,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) } - static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey} + static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey} b := NewBits(ReplayWindow) // Clear out bit 0, we never transmit it and we don't want it showing as packet loss diff --git a/control_tester.go b/control_tester.go index dd1a774..a26c8bb 100644 --- a/control_tester.go +++ b/control_tester.go @@ -161,7 +161,7 @@ func (c *Control) GetHostmap() *HostMap { } func (c *Control) GetCert() *cert.NebulaCertificate { - return c.f.certState.Load().certificate + return c.f.pki.GetCertState().Certificate } func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { diff --git a/handshake_ix.go b/handshake_ix.go index 70263b9..52efdf5 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -33,7 +33,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { hsProto := &NebulaHandshakeDetails{ InitiatorIndex: hostinfo.localIndexId, Time: uint64(time.Now().UnixNano()), - Cert: ci.certState.rawCertificateNoKey, + Cert: ci.certState.RawCertificateNoKey, } hsBytes := []byte{} @@ -91,7 +91,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by return } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) if err != nil { f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). @@ -155,7 +155,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by Info("Handshake message received") hs.Details.ResponderIndex = myIndex - hs.Details.Cert = ci.certState.rawCertificateNoKey + hs.Details.Cert = ci.certState.RawCertificateNoKey // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) @@ -399,7 +399,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H return true } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) if err != nil { f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). diff --git a/inside.go b/inside.go index 0d43926..0fac833 100644 --- a/inside.go +++ b/inside.go @@ -69,7 +69,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet ci.queueLock.Unlock() } - dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache) + dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, packet, nb, out, q) @@ -183,7 +183,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.caPool, nil) + dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). diff --git a/interface.go b/interface.go index 771aed0..fbf610a 100644 --- a/interface.go +++ b/interface.go @@ -13,7 +13,6 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -28,7 +27,7 @@ type InterfaceConfig struct { HostMap *HostMap Outside udp.Conn Inside overlay.Device - certState *CertState + pki *PKI Cipher string Firewall *Firewall ServeDns bool @@ -41,7 +40,6 @@ type InterfaceConfig struct { routines int MessageMetrics *MessageMetrics version string - caPool *cert.NebulaCAPool disconnectInvalid bool relayManager *relayManager punchy *Punchy @@ -58,7 +56,7 @@ type Interface struct { hostMap *HostMap outside udp.Conn inside overlay.Device - certState atomic.Pointer[CertState] + pki *PKI cipher string firewall *Firewall connectionManager *connectionManager @@ -71,7 +69,6 @@ type Interface struct { dropLocalBroadcast bool dropMulticast bool routines int - caPool *cert.NebulaCAPool disconnectInvalid bool closed atomic.Bool relayManager *relayManager @@ -152,15 +149,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { if c.Inside == nil { return nil, errors.New("no inside interface (tun)") } - if c.certState == nil { + if c.pki == nil { return nil, errors.New("no certificate state") } if c.Firewall == nil { return nil, errors.New("no firewall rules") } - myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP) + certificate := c.pki.GetCertState().Certificate + myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP) ifce := &Interface{ + pki: c.pki, hostMap: c.HostMap, outside: c.Outside, inside: c.Inside, @@ -170,14 +169,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, - localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask), + localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask), dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast, routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - caPool: c.caPool, disconnectInvalid: c.disconnectInvalid, myVpnIp: myVpnIp, relayManager: c.relayManager, @@ -198,7 +196,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) - ifce.certState.Store(c.certState) ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy) return ifce, nil @@ -295,8 +292,6 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { - c.RegisterReloadCallback(f.reloadCA) - c.RegisterReloadCallback(f.reloadCertKey) c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) c.RegisterReloadCallback(f.reloadMisc) @@ -305,40 +300,6 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { } } -func (f *Interface) reloadCA(c *config.C) { - // reload and check regardless - // todo: need mutex? - newCAs, err := loadCAFromConfig(f.l, c) - if err != nil { - f.l.WithError(err).Error("Could not refresh trusted CA certificates") - return - } - - f.caPool = newCAs - f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed") -} - -func (f *Interface) reloadCertKey(c *config.C) { - // reload and check in all cases - cs, err := NewCertStateFromConfig(c) - if err != nil { - f.l.WithError(err).Error("Could not refresh client cert") - return - } - - // did IP in cert change? if so, don't set - currentCert := f.certState.Load().certificate - oldIPs := currentCert.Details.Ips - newIPs := cs.certificate.Details.Ips - if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { - f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old") - return - } - - f.certState.Store(cs) - f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk") -} - func (f *Interface) reloadFirewall(c *config.C) { //TODO: need to trigger/detect if the certificate changed too if c.HasChanged("firewall") == false { @@ -346,7 +307,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return @@ -438,7 +399,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() - certExpirationGauge.Update(int64(f.certState.Load().certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) + certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) } } } diff --git a/lighthouse.go b/lighthouse.go index 6c46663..9b3b837 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -132,7 +132,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, c.RegisterReloadCallback(func(c *config.C) { err := h.reload(c, false) switch v := err.(type) { - case util.ContextualError: + case *util.ContextualError: v.Log(l) case error: l.WithError(err).Error("failed to reload lighthouse") diff --git a/main.go b/main.go index d050db2..4e8448b 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package nebula import ( "context" "encoding/binary" - "errors" "fmt" "net" "time" @@ -46,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg err := configLogger(l, c) if err != nil { - return nil, util.NewContextualError("Failed to configure the logger", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err) } c.RegisterReloadCallback(func(c *config.C) { @@ -56,28 +55,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } }) - caPool, err := loadCAFromConfig(l, c) + pki, err := NewPKIFromConfig(l, c) if err != nil { - //The errors coming out of loadCA are already nicely formatted - return nil, util.NewContextualError("Failed to load ca from config", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") - cs, err := NewCertStateFromConfig(c) + certificate := pki.GetCertState().Certificate + fw, err := NewFirewallFromConfig(l, certificate, c) if err != nil { - //The errors coming out of NewCertStateFromConfig are already nicely formatted - return nil, util.NewContextualError("Failed to load certificate from config", nil, err) - } - l.WithField("cert", cs.certificate).Debug("Client nebula certificate") - - fw, err := NewFirewallFromConfig(l, cs.certificate, c) - if err != nil { - return nil, util.NewContextualError("Error while loading firewall rules", nil, err) + return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") // TODO: make sure mask is 4 bytes - tunCidr := cs.certificate.Details.Ips[0] + tunCidr := certificate.Details.Ips[0] ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) wireSSHReload(l, ssh, c) @@ -85,7 +76,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if c.GetBool("sshd.enabled", false) { sshStart, err = configSSH(l, ssh, c) if err != nil { - return nil, util.NewContextualError("Error while configuring the sshd", nil, err) + return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err) } } @@ -136,7 +127,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines) if err != nil { - return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } defer func() { @@ -160,7 +151,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } else { listenHost, err = net.ResolveIPAddr("ip", rawListenHost) if err != nil { - return nil, util.NewContextualError("Failed to resolve listen.host", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) } } @@ -182,7 +173,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg for _, rawPreferredRange := range rawPreferredRanges { _, preferredRange, err := net.ParseCIDR(rawPreferredRange) if err != nil { - return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err) } preferredRanges = append(preferredRanges, preferredRange) } @@ -195,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if rawLocalRange != "" { _, localRange, err := net.ParseCIDR(rawLocalRange) if err != nil { - return nil, util.NewContextualError("Failed to parse local_range", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err) } // Check if the entry for local_range was already specified in @@ -222,11 +213,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg punchy := NewPunchyFromConfig(l, c) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) - switch { - case errors.As(err, &util.ContextualError{}): - return nil, err - case err != nil: - return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, err) + if err != nil { + return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) } var messageMetrics *MessageMetrics @@ -266,7 +254,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg HostMap: hostMap, Inside: tun, Outside: udpConns[0], - certState: cs, + pki: pki, Cipher: c.GetString("cipher", "aes"), Firewall: fw, ServeDns: serveDns, @@ -282,7 +270,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg routines: routines, MessageMetrics: messageMetrics, version: buildVersion, - caPool: caPool, disconnectInvalid: c.GetBool("pki.disconnect_invalid", false), relayManager: NewRelayManager(ctx, l, hostMap, c), punchy: punchy, @@ -321,9 +308,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg // TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept // a context so that they can exit when the context is Done. statsStart, err := startStats(l, c, buildVersion, configTest) - if err != nil { - return nil, util.NewContextualError("Failed to start stats emitter", nil, err) + return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) } if configTest { diff --git a/outside.go b/outside.go index 19a980b..a9dcdc8 100644 --- a/outside.go +++ b/outside.go @@ -404,7 +404,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache) + dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q) if f.l.Level >= logrus.DebugLevel { diff --git a/pki.go b/pki.go new file mode 100644 index 0000000..91478ce --- /dev/null +++ b/pki.go @@ -0,0 +1,248 @@ +package nebula + +import ( + "errors" + "fmt" + "os" + "strings" + "sync/atomic" + "time" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" +) + +type PKI struct { + cs atomic.Pointer[CertState] + caPool atomic.Pointer[cert.NebulaCAPool] + l *logrus.Logger +} + +type CertState struct { + Certificate *cert.NebulaCertificate + RawCertificate []byte + RawCertificateNoKey []byte + PublicKey []byte + PrivateKey []byte +} + +func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { + pki := &PKI{l: l} + err := pki.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + rErr := pki.reload(c, false) + if rErr != nil { + util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l) + } + }) + + return pki, nil +} + +func (p *PKI) GetCertState() *CertState { + return p.cs.Load() +} + +func (p *PKI) GetCAPool() *cert.NebulaCAPool { + return p.caPool.Load() +} + +func (p *PKI) reload(c *config.C, initial bool) error { + err := p.reloadCert(c, initial) + if err != nil { + if initial { + return err + } + err.Log(p.l) + } + + err = p.reloadCAPool(c) + if err != nil { + if initial { + return err + } + err.Log(p.l) + } + + return nil +} + +func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { + cs, err := newCertStateFromConfig(c) + if err != nil { + return util.NewContextualError("Could not load client cert", nil, err) + } + + if !initial { + // did IP in cert change? if so, don't set + currentCert := p.cs.Load().Certificate + oldIPs := currentCert.Details.Ips + newIPs := cs.Certificate.Details.Ips + if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { + return util.NewContextualError( + "IP in new cert was different from old", + m{"new_ip": newIPs[0], "old_ip": oldIPs[0]}, + nil, + ) + } + } + + p.cs.Store(cs) + if initial { + p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") + } else { + p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") + } + return nil +} + +func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { + caPool, err := loadCAPoolFromConfig(p.l, c) + if err != nil { + return util.NewContextualError("Failed to load ca from config", nil, err) + } + + p.caPool.Store(caPool) + p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") + return nil +} + +func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { + // Marshal the certificate to ensure it is valid + rawCertificate, err := certificate.Marshal() + if err != nil { + return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) + } + + publicKey := certificate.Details.PublicKey + cs := &CertState{ + RawCertificate: rawCertificate, + Certificate: certificate, + PrivateKey: privateKey, + PublicKey: publicKey, + } + + cs.Certificate.Details.PublicKey = nil + rawCertNoKey, err := cs.Certificate.Marshal() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate no key: %s", err) + } + cs.RawCertificateNoKey = rawCertNoKey + // put public key back + cs.Certificate.Details.PublicKey = cs.PublicKey + return cs, nil +} + +func newCertStateFromConfig(c *config.C) (*CertState, error) { + var pemPrivateKey []byte + var err error + + privPathOrPEM := c.GetString("pki.key", "") + if privPathOrPEM == "" { + return nil, errors.New("no pki.key path or PEM data provided") + } + + if strings.Contains(privPathOrPEM, "-----BEGIN") { + pemPrivateKey = []byte(privPathOrPEM) + privPathOrPEM = "" + + } else { + pemPrivateKey, err = os.ReadFile(privPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) + } + } + + rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey) + if err != nil { + return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + + var rawCert []byte + + pubPathOrPEM := c.GetString("pki.cert", "") + if pubPathOrPEM == "" { + return nil, errors.New("no pki.cert path or PEM data provided") + } + + if strings.Contains(pubPathOrPEM, "-----BEGIN") { + rawCert = []byte(pubPathOrPEM) + pubPathOrPEM = "" + + } else { + rawCert, err = os.ReadFile(pubPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err) + } + } + + nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) + if err != nil { + return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) + } + + if nebulaCert.Expired(time.Now()) { + return nil, fmt.Errorf("nebula certificate for this host is expired") + } + + if len(nebulaCert.Details.Ips) == 0 { + return nil, fmt.Errorf("no IPs encoded in certificate") + } + + if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + + return newCertState(nebulaCert, rawKey) +} + +func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { + var rawCA []byte + var err error + + caPathOrPEM := c.GetString("pki.ca", "") + if caPathOrPEM == "" { + return nil, errors.New("no pki.ca path or PEM data provided") + } + + if strings.Contains(caPathOrPEM, "-----BEGIN") { + rawCA = []byte(caPathOrPEM) + + } else { + rawCA, err = os.ReadFile(caPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) + } + } + + caPool, err := cert.NewCAPoolFromBytes(rawCA) + if errors.Is(err, cert.ErrExpired) { + var expired int + for _, crt := range caPool.CAs { + if crt.Expired(time.Now()) { + expired++ + l.WithField("cert", crt).Warn("expired certificate present in CA pool") + } + } + + if expired >= len(caPool.CAs) { + return nil, errors.New("no valid CA certificates present") + } + + } else if err != nil { + return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) + } + + for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) { + l.WithField("fingerprint", fp).Info("Blocklisting cert") + caPool.BlocklistFingerprint(fp) + } + + return caPool, nil +} diff --git a/ssh.go b/ssh.go index 0f624db..44286c8 100644 --- a/ssh.go +++ b/ssh.go @@ -754,7 +754,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return nil } - cert := ifce.certState.Load().certificate + cert := ifce.pki.GetCertState().Certificate if len(a) > 0 { parsedIp := net.ParseIP(a[0]) if parsedIp == nil { diff --git a/util/error.go b/util/error.go index 7f9bc47..a11c9c4 100644 --- a/util/error.go +++ b/util/error.go @@ -12,18 +12,38 @@ type ContextualError struct { Context string } -func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { - return ContextualError{Context: msg, Fields: fields, RealError: realError} +func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError { + return &ContextualError{Context: msg, Fields: fields, RealError: realError} } -func (ce ContextualError) Error() string { +// ContextualizeIfNeeded is a helper function to turn an error into a ContextualError if it is not already one +func ContextualizeIfNeeded(msg string, err error) error { + switch err.(type) { + case *ContextualError: + return err + default: + return NewContextualError(msg, nil, err) + } +} + +// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError +func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) { + switch v := err.(type) { + case *ContextualError: + v.Log(l) + default: + l.WithError(err).Error(msg) + } +} + +func (ce *ContextualError) Error() string { if ce.RealError == nil { return ce.Context } return ce.RealError.Error() } -func (ce ContextualError) Unwrap() error { +func (ce *ContextualError) Unwrap() error { if ce.RealError == nil { return errors.New(ce.Context) } diff --git a/util/error_test.go b/util/error_test.go index 747d04e..5041f82 100644 --- a/util/error_test.go +++ b/util/error_test.go @@ -2,6 +2,7 @@ package util import ( "errors" + "fmt" "testing" "github.com/sirupsen/logrus" @@ -67,3 +68,44 @@ func TestContextualError_Log(t *testing.T) { e.Log(l) assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) } + +func TestLogWithContextIfNeeded(t *testing.T) { + l := logrus.New() + l.Formatter = &logrus.TextFormatter{ + DisableTimestamp: true, + DisableColors: true, + } + + tl := NewTestLogWriter() + l.Out = tl + + // Test ignoring fallback context + tl.Reset() + e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) + LogWithContextIfNeeded("This should get thrown away", e, l) + assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + + // Test using fallback context + tl.Reset() + err := fmt.Errorf("this is a normal error") + LogWithContextIfNeeded("Fallback context woo", err, l) + assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs) +} + +func TestContextualizeIfNeeded(t *testing.T) { + // Test ignoring fallback context + e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) + assert.Same(t, e, ContextualizeIfNeeded("should be ignored", e)) + + // Test using fallback context + err := fmt.Errorf("this is a normal error") + cErr := ContextualizeIfNeeded("Fallback context woo", err) + + switch v := cErr.(type) { + case *ContextualError: + assert.Equal(t, err, v.RealError) + default: + t.Error("Error was not wrapped") + t.Fail() + } +}