mirror of https://github.com/slackhq/nebula.git
Combine ca, cert, and key handling (#952)
This commit is contained in:
parent
223cc6e660
commit
5a131b2975
163
cert.go
163
cert.go
|
@ -1,163 +0,0 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
type CertState struct {
|
||||
certificate *cert.NebulaCertificate
|
||||
rawCertificate []byte
|
||||
rawCertificateNoKey []byte
|
||||
publicKey []byte
|
||||
privateKey []byte
|
||||
}
|
||||
|
||||
func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
|
||||
// Marshal the certificate to ensure it is valid
|
||||
rawCertificate, err := certificate.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
|
||||
}
|
||||
|
||||
publicKey := certificate.Details.PublicKey
|
||||
cs := &CertState{
|
||||
rawCertificate: rawCertificate,
|
||||
certificate: certificate, // PublicKey has been set to nil above
|
||||
privateKey: privateKey,
|
||||
publicKey: publicKey,
|
||||
}
|
||||
|
||||
cs.certificate.Details.PublicKey = nil
|
||||
rawCertNoKey, err := cs.certificate.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
|
||||
}
|
||||
cs.rawCertificateNoKey = rawCertNoKey
|
||||
// put public key back
|
||||
cs.certificate.Details.PublicKey = cs.publicKey
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func NewCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||
var pemPrivateKey []byte
|
||||
var err error
|
||||
|
||||
privPathOrPEM := c.GetString("pki.key", "")
|
||||
|
||||
if privPathOrPEM == "" {
|
||||
return nil, errors.New("no pki.key path or PEM data provided")
|
||||
}
|
||||
|
||||
if strings.Contains(privPathOrPEM, "-----BEGIN") {
|
||||
pemPrivateKey = []byte(privPathOrPEM)
|
||||
privPathOrPEM = "<inline>"
|
||||
} else {
|
||||
pemPrivateKey, err = ioutil.ReadFile(privPathOrPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
|
||||
}
|
||||
}
|
||||
|
||||
rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
|
||||
}
|
||||
|
||||
var rawCert []byte
|
||||
|
||||
pubPathOrPEM := c.GetString("pki.cert", "")
|
||||
|
||||
if pubPathOrPEM == "" {
|
||||
return nil, errors.New("no pki.cert path or PEM data provided")
|
||||
}
|
||||
|
||||
if strings.Contains(pubPathOrPEM, "-----BEGIN") {
|
||||
rawCert = []byte(pubPathOrPEM)
|
||||
pubPathOrPEM = "<inline>"
|
||||
} else {
|
||||
rawCert, err = ioutil.ReadFile(pubPathOrPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
|
||||
}
|
||||
}
|
||||
|
||||
nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
|
||||
}
|
||||
|
||||
if nebulaCert.Expired(time.Now()) {
|
||||
return nil, fmt.Errorf("nebula certificate for this host is expired")
|
||||
}
|
||||
|
||||
if len(nebulaCert.Details.Ips) == 0 {
|
||||
return nil, fmt.Errorf("no IPs encoded in certificate")
|
||||
}
|
||||
|
||||
if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
|
||||
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
|
||||
}
|
||||
|
||||
return NewCertState(nebulaCert, rawKey)
|
||||
}
|
||||
|
||||
func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
|
||||
var rawCA []byte
|
||||
var err error
|
||||
|
||||
caPathOrPEM := c.GetString("pki.ca", "")
|
||||
if caPathOrPEM == "" {
|
||||
return nil, errors.New("no pki.ca path or PEM data provided")
|
||||
}
|
||||
|
||||
if strings.Contains(caPathOrPEM, "-----BEGIN") {
|
||||
rawCA = []byte(caPathOrPEM)
|
||||
|
||||
} else {
|
||||
rawCA, err = ioutil.ReadFile(caPathOrPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
|
||||
}
|
||||
}
|
||||
|
||||
CAs, err := cert.NewCAPoolFromBytes(rawCA)
|
||||
if errors.Is(err, cert.ErrExpired) {
|
||||
var expired int
|
||||
for _, cert := range CAs.CAs {
|
||||
if cert.Expired(time.Now()) {
|
||||
expired++
|
||||
l.WithField("cert", cert).Warn("expired certificate present in CA pool")
|
||||
}
|
||||
}
|
||||
|
||||
if expired >= len(CAs.CAs) {
|
||||
return nil, errors.New("no valid CA certificates present")
|
||||
}
|
||||
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
||||
}
|
||||
|
||||
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
|
||||
l.WithField("fingerprint", fp).Info("Blocklisting cert")
|
||||
CAs.BlocklistFingerprint(fp)
|
||||
}
|
||||
|
||||
// Support deprecated config for at least one minor release to allow for migrations
|
||||
//TODO: remove in 2022 or later
|
||||
for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
|
||||
l.WithField("fingerprint", fp).Info("Blocklisting cert")
|
||||
l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist")
|
||||
CAs.BlocklistFingerprint(fp)
|
||||
}
|
||||
|
||||
return CAs, nil
|
||||
}
|
|
@ -59,13 +59,8 @@ func main() {
|
|||
}
|
||||
|
||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||
|
||||
switch v := err.(type) {
|
||||
case util.ContextualError:
|
||||
v.Log(l)
|
||||
os.Exit(1)
|
||||
case error:
|
||||
l.WithError(err).Error("Failed to start")
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to start", err, l)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
|
|
|
@ -53,13 +53,8 @@ func main() {
|
|||
}
|
||||
|
||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||
|
||||
switch v := err.(type) {
|
||||
case util.ContextualError:
|
||||
v.Log(l)
|
||||
os.Exit(1)
|
||||
case error:
|
||||
l.WithError(err).Error("Failed to start")
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to start", err, l)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
|
|
|
@ -405,8 +405,8 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
certState := n.intf.certState.Load()
|
||||
return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)
|
||||
certState := n.intf.pki.GetCertState()
|
||||
return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature)
|
||||
}
|
||||
|
||||
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||
|
@ -427,7 +427,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
|
|||
return false
|
||||
}
|
||||
|
||||
valid, err := remoteCert.VerifyWithCache(now, n.intf.caPool)
|
||||
valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
|
||||
if valid {
|
||||
return false
|
||||
}
|
||||
|
@ -464,8 +464,8 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
|
|||
}
|
||||
|
||||
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
certState := n.intf.certState.Load()
|
||||
if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {
|
||||
certState := n.intf.pki.GetCertState()
|
||||
if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -44,10 +44,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
// Very incomplete mock objects
|
||||
hostMap := NewHostMap(l, vpncidr, preferredRanges)
|
||||
cs := &CertState{
|
||||
rawCertificate: []byte{},
|
||||
privateKey: []byte{},
|
||||
certificate: &cert.NebulaCertificate{},
|
||||
rawCertificateNoKey: []byte{},
|
||||
RawCertificate: []byte{},
|
||||
PrivateKey: []byte{},
|
||||
Certificate: &cert.NebulaCertificate{},
|
||||
RawCertificateNoKey: []byte{},
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
|
@ -57,10 +57,11 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
outside: &udp.NoopConn{},
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
pki: &PKI{},
|
||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
|
||||
l: l,
|
||||
}
|
||||
ifce.certState.Store(cs)
|
||||
ifce.pki.cs.Store(cs)
|
||||
|
||||
// Create manager
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -123,10 +124,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||
// Very incomplete mock objects
|
||||
hostMap := NewHostMap(l, vpncidr, preferredRanges)
|
||||
cs := &CertState{
|
||||
rawCertificate: []byte{},
|
||||
privateKey: []byte{},
|
||||
certificate: &cert.NebulaCertificate{},
|
||||
rawCertificateNoKey: []byte{},
|
||||
RawCertificate: []byte{},
|
||||
PrivateKey: []byte{},
|
||||
Certificate: &cert.NebulaCertificate{},
|
||||
RawCertificateNoKey: []byte{},
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
|
@ -136,10 +137,11 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||
outside: &udp.NoopConn{},
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
pki: &PKI{},
|
||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
|
||||
l: l,
|
||||
}
|
||||
ifce.certState.Store(cs)
|
||||
ifce.pki.cs.Store(cs)
|
||||
|
||||
// Create manager
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -242,10 +244,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||
peerCert.Sign(cert.Curve_CURVE25519, privCA)
|
||||
|
||||
cs := &CertState{
|
||||
rawCertificate: []byte{},
|
||||
privateKey: []byte{},
|
||||
certificate: &cert.NebulaCertificate{},
|
||||
rawCertificateNoKey: []byte{},
|
||||
RawCertificate: []byte{},
|
||||
PrivateKey: []byte{},
|
||||
Certificate: &cert.NebulaCertificate{},
|
||||
RawCertificateNoKey: []byte{},
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
|
@ -258,9 +260,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
|
||||
l: l,
|
||||
disconnectInvalid: true,
|
||||
caPool: ncp,
|
||||
pki: &PKI{},
|
||||
}
|
||||
ifce.certState.Store(cs)
|
||||
ifce.pki.cs.Store(cs)
|
||||
ifce.pki.caPool.Store(ncp)
|
||||
|
||||
// Create manager
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
|
|
@ -30,15 +30,15 @@ type ConnectionState struct {
|
|||
|
||||
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||
var dhFunc noise.DHFunc
|
||||
curCertState := f.certState.Load()
|
||||
curCertState := f.pki.GetCertState()
|
||||
|
||||
switch curCertState.certificate.Details.Curve {
|
||||
switch curCertState.Certificate.Details.Curve {
|
||||
case cert.Curve_CURVE25519:
|
||||
dhFunc = noise.DH25519
|
||||
case cert.Curve_P256:
|
||||
dhFunc = noiseutil.DHP256
|
||||
default:
|
||||
l.Errorf("invalid curve: %s", curCertState.certificate.Details.Curve)
|
||||
l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve)
|
||||
return nil
|
||||
}
|
||||
cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
|
||||
|
@ -46,7 +46,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
|
|||
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
}
|
||||
|
||||
static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
|
||||
static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey}
|
||||
|
||||
b := NewBits(ReplayWindow)
|
||||
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
|
||||
|
|
|
@ -161,7 +161,7 @@ func (c *Control) GetHostmap() *HostMap {
|
|||
}
|
||||
|
||||
func (c *Control) GetCert() *cert.NebulaCertificate {
|
||||
return c.f.certState.Load().certificate
|
||||
return c.f.pki.GetCertState().Certificate
|
||||
}
|
||||
|
||||
func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
|
||||
|
|
|
@ -33,7 +33,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
|
|||
hsProto := &NebulaHandshakeDetails{
|
||||
InitiatorIndex: hostinfo.localIndexId,
|
||||
Time: uint64(time.Now().UnixNano()),
|
||||
Cert: ci.certState.rawCertificateNoKey,
|
||||
Cert: ci.certState.RawCertificateNoKey,
|
||||
}
|
||||
|
||||
hsBytes := []byte{}
|
||||
|
@ -91,7 +91,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
|
|||
return
|
||||
}
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||
|
@ -155,7 +155,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
|
|||
Info("Handshake message received")
|
||||
|
||||
hs.Details.ResponderIndex = myIndex
|
||||
hs.Details.Cert = ci.certState.rawCertificateNoKey
|
||||
hs.Details.Cert = ci.certState.RawCertificateNoKey
|
||||
// Update the time in case their clock is way off from ours
|
||||
hs.Details.Time = uint64(time.Now().UnixNano())
|
||||
|
||||
|
@ -399,7 +399,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
|
|||
return true
|
||||
}
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
||||
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
|
|
|
@ -69,7 +69,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||
ci.queueLock.Unlock()
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
|
||||
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason == nil {
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, packet, nb, out, q)
|
||||
|
||||
|
@ -183,7 +183,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||
}
|
||||
|
||||
// check if packet is in outbound fw rules
|
||||
dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.caPool, nil)
|
||||
dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||
if dropReason != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("fwPacket", fp).
|
||||
|
|
57
interface.go
57
interface.go
|
@ -13,7 +13,6 @@ import (
|
|||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
|
@ -28,7 +27,7 @@ type InterfaceConfig struct {
|
|||
HostMap *HostMap
|
||||
Outside udp.Conn
|
||||
Inside overlay.Device
|
||||
certState *CertState
|
||||
pki *PKI
|
||||
Cipher string
|
||||
Firewall *Firewall
|
||||
ServeDns bool
|
||||
|
@ -41,7 +40,6 @@ type InterfaceConfig struct {
|
|||
routines int
|
||||
MessageMetrics *MessageMetrics
|
||||
version string
|
||||
caPool *cert.NebulaCAPool
|
||||
disconnectInvalid bool
|
||||
relayManager *relayManager
|
||||
punchy *Punchy
|
||||
|
@ -58,7 +56,7 @@ type Interface struct {
|
|||
hostMap *HostMap
|
||||
outside udp.Conn
|
||||
inside overlay.Device
|
||||
certState atomic.Pointer[CertState]
|
||||
pki *PKI
|
||||
cipher string
|
||||
firewall *Firewall
|
||||
connectionManager *connectionManager
|
||||
|
@ -71,7 +69,6 @@ type Interface struct {
|
|||
dropLocalBroadcast bool
|
||||
dropMulticast bool
|
||||
routines int
|
||||
caPool *cert.NebulaCAPool
|
||||
disconnectInvalid bool
|
||||
closed atomic.Bool
|
||||
relayManager *relayManager
|
||||
|
@ -152,15 +149,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||
if c.Inside == nil {
|
||||
return nil, errors.New("no inside interface (tun)")
|
||||
}
|
||||
if c.certState == nil {
|
||||
if c.pki == nil {
|
||||
return nil, errors.New("no certificate state")
|
||||
}
|
||||
if c.Firewall == nil {
|
||||
return nil, errors.New("no firewall rules")
|
||||
}
|
||||
|
||||
myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
|
||||
certificate := c.pki.GetCertState().Certificate
|
||||
myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
|
||||
ifce := &Interface{
|
||||
pki: c.pki,
|
||||
hostMap: c.HostMap,
|
||||
outside: c.Outside,
|
||||
inside: c.Inside,
|
||||
|
@ -170,14 +169,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||
handshakeManager: c.HandshakeManager,
|
||||
createTime: time.Now(),
|
||||
lightHouse: c.lightHouse,
|
||||
localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask),
|
||||
localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
|
||||
dropLocalBroadcast: c.DropLocalBroadcast,
|
||||
dropMulticast: c.DropMulticast,
|
||||
routines: c.routines,
|
||||
version: c.version,
|
||||
writers: make([]udp.Conn, c.routines),
|
||||
readers: make([]io.ReadWriteCloser, c.routines),
|
||||
caPool: c.caPool,
|
||||
disconnectInvalid: c.disconnectInvalid,
|
||||
myVpnIp: myVpnIp,
|
||||
relayManager: c.relayManager,
|
||||
|
@ -198,7 +196,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||
|
||||
ifce.certState.Store(c.certState)
|
||||
ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
|
||||
|
||||
return ifce, nil
|
||||
|
@ -295,8 +292,6 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
|||
}
|
||||
|
||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||
c.RegisterReloadCallback(f.reloadCA)
|
||||
c.RegisterReloadCallback(f.reloadCertKey)
|
||||
c.RegisterReloadCallback(f.reloadFirewall)
|
||||
c.RegisterReloadCallback(f.reloadSendRecvError)
|
||||
c.RegisterReloadCallback(f.reloadMisc)
|
||||
|
@ -305,40 +300,6 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
|||
}
|
||||
}
|
||||
|
||||
func (f *Interface) reloadCA(c *config.C) {
|
||||
// reload and check regardless
|
||||
// todo: need mutex?
|
||||
newCAs, err := loadCAFromConfig(f.l, c)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Could not refresh trusted CA certificates")
|
||||
return
|
||||
}
|
||||
|
||||
f.caPool = newCAs
|
||||
f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
|
||||
}
|
||||
|
||||
func (f *Interface) reloadCertKey(c *config.C) {
|
||||
// reload and check in all cases
|
||||
cs, err := NewCertStateFromConfig(c)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Could not refresh client cert")
|
||||
return
|
||||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
currentCert := f.certState.Load().certificate
|
||||
oldIPs := currentCert.Details.Ips
|
||||
newIPs := cs.certificate.Details.Ips
|
||||
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
|
||||
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
|
||||
return
|
||||
}
|
||||
|
||||
f.certState.Store(cs)
|
||||
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
|
||||
}
|
||||
|
||||
func (f *Interface) reloadFirewall(c *config.C) {
|
||||
//TODO: need to trigger/detect if the certificate changed too
|
||||
if c.HasChanged("firewall") == false {
|
||||
|
@ -346,7 +307,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||
return
|
||||
}
|
||||
|
||||
fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c)
|
||||
fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
||||
return
|
||||
|
@ -438,7 +399,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
|||
f.firewall.EmitStats()
|
||||
f.handshakeManager.EmitStats()
|
||||
udpStats()
|
||||
certExpirationGauge.Update(int64(f.certState.Load().certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
|
||||
certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -132,7 +132,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
|
|||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := h.reload(c, false)
|
||||
switch v := err.(type) {
|
||||
case util.ContextualError:
|
||||
case *util.ContextualError:
|
||||
v.Log(l)
|
||||
case error:
|
||||
l.WithError(err).Error("failed to reload lighthouse")
|
||||
|
|
46
main.go
46
main.go
|
@ -3,7 +3,6 @@ package nebula
|
|||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
@ -46,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
|
||||
err := configLogger(l, c)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to configure the logger", nil, err)
|
||||
return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
|
||||
}
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
|
@ -56,28 +55,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
}
|
||||
})
|
||||
|
||||
caPool, err := loadCAFromConfig(l, c)
|
||||
pki, err := NewPKIFromConfig(l, c)
|
||||
if err != nil {
|
||||
//The errors coming out of loadCA are already nicely formatted
|
||||
return nil, util.NewContextualError("Failed to load ca from config", nil, err)
|
||||
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
||||
}
|
||||
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||
|
||||
cs, err := NewCertStateFromConfig(c)
|
||||
certificate := pki.GetCertState().Certificate
|
||||
fw, err := NewFirewallFromConfig(l, certificate, c)
|
||||
if err != nil {
|
||||
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
||||
return nil, util.NewContextualError("Failed to load certificate from config", nil, err)
|
||||
}
|
||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||
|
||||
fw, err := NewFirewallFromConfig(l, cs.certificate, c)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Error while loading firewall rules", nil, err)
|
||||
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
||||
}
|
||||
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
|
||||
|
||||
// TODO: make sure mask is 4 bytes
|
||||
tunCidr := cs.certificate.Details.Ips[0]
|
||||
tunCidr := certificate.Details.Ips[0]
|
||||
|
||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||
wireSSHReload(l, ssh, c)
|
||||
|
@ -85,7 +76,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
if c.GetBool("sshd.enabled", false) {
|
||||
sshStart, err = configSSH(l, ssh, c)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Error while configuring the sshd", nil, err)
|
||||
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -136,7 +127,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
|
||||
tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
|
||||
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
|
@ -160,7 +151,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
} else {
|
||||
listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to resolve listen.host", nil, err)
|
||||
return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -182,7 +173,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
for _, rawPreferredRange := range rawPreferredRanges {
|
||||
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err)
|
||||
return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
|
||||
}
|
||||
preferredRanges = append(preferredRanges, preferredRange)
|
||||
}
|
||||
|
@ -195,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
if rawLocalRange != "" {
|
||||
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to parse local_range", nil, err)
|
||||
return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
|
||||
}
|
||||
|
||||
// Check if the entry for local_range was already specified in
|
||||
|
@ -222,11 +213,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
|
||||
punchy := NewPunchyFromConfig(l, c)
|
||||
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
|
||||
switch {
|
||||
case errors.As(err, &util.ContextualError{}):
|
||||
return nil, err
|
||||
case err != nil:
|
||||
return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, err)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
|
||||
}
|
||||
|
||||
var messageMetrics *MessageMetrics
|
||||
|
@ -266,7 +254,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
HostMap: hostMap,
|
||||
Inside: tun,
|
||||
Outside: udpConns[0],
|
||||
certState: cs,
|
||||
pki: pki,
|
||||
Cipher: c.GetString("cipher", "aes"),
|
||||
Firewall: fw,
|
||||
ServeDns: serveDns,
|
||||
|
@ -282,7 +270,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
routines: routines,
|
||||
MessageMetrics: messageMetrics,
|
||||
version: buildVersion,
|
||||
caPool: caPool,
|
||||
disconnectInvalid: c.GetBool("pki.disconnect_invalid", false),
|
||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||
punchy: punchy,
|
||||
|
@ -321,9 +308,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
|
||||
// a context so that they can exit when the context is Done.
|
||||
statsStart, err := startStats(l, c, buildVersion, configTest)
|
||||
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to start stats emitter", nil, err)
|
||||
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
|
||||
}
|
||||
|
||||
if configTest {
|
||||
|
|
|
@ -404,7 +404,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||
return false
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
|
||||
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason != nil {
|
||||
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
|
|
|
@ -0,0 +1,248 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
type PKI struct {
|
||||
cs atomic.Pointer[CertState]
|
||||
caPool atomic.Pointer[cert.NebulaCAPool]
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type CertState struct {
|
||||
Certificate *cert.NebulaCertificate
|
||||
RawCertificate []byte
|
||||
RawCertificateNoKey []byte
|
||||
PublicKey []byte
|
||||
PrivateKey []byte
|
||||
}
|
||||
|
||||
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
||||
pki := &PKI{l: l}
|
||||
err := pki.reload(c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
rErr := pki.reload(c, false)
|
||||
if rErr != nil {
|
||||
util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l)
|
||||
}
|
||||
})
|
||||
|
||||
return pki, nil
|
||||
}
|
||||
|
||||
func (p *PKI) GetCertState() *CertState {
|
||||
return p.cs.Load()
|
||||
}
|
||||
|
||||
func (p *PKI) GetCAPool() *cert.NebulaCAPool {
|
||||
return p.caPool.Load()
|
||||
}
|
||||
|
||||
func (p *PKI) reload(c *config.C, initial bool) error {
|
||||
err := p.reloadCert(c, initial)
|
||||
if err != nil {
|
||||
if initial {
|
||||
return err
|
||||
}
|
||||
err.Log(p.l)
|
||||
}
|
||||
|
||||
err = p.reloadCAPool(c)
|
||||
if err != nil {
|
||||
if initial {
|
||||
return err
|
||||
}
|
||||
err.Log(p.l)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
|
||||
cs, err := newCertStateFromConfig(c)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Could not load client cert", nil, err)
|
||||
}
|
||||
|
||||
if !initial {
|
||||
// did IP in cert change? if so, don't set
|
||||
currentCert := p.cs.Load().Certificate
|
||||
oldIPs := currentCert.Details.Ips
|
||||
newIPs := cs.Certificate.Details.Ips
|
||||
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
|
||||
return util.NewContextualError(
|
||||
"IP in new cert was different from old",
|
||||
m{"new_ip": newIPs[0], "old_ip": oldIPs[0]},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
p.cs.Store(cs)
|
||||
if initial {
|
||||
p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
|
||||
} else {
|
||||
p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
|
||||
caPool, err := loadCAPoolFromConfig(p.l, c)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Failed to load ca from config", nil, err)
|
||||
}
|
||||
|
||||
p.caPool.Store(caPool)
|
||||
p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
|
||||
// Marshal the certificate to ensure it is valid
|
||||
rawCertificate, err := certificate.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
|
||||
}
|
||||
|
||||
publicKey := certificate.Details.PublicKey
|
||||
cs := &CertState{
|
||||
RawCertificate: rawCertificate,
|
||||
Certificate: certificate,
|
||||
PrivateKey: privateKey,
|
||||
PublicKey: publicKey,
|
||||
}
|
||||
|
||||
cs.Certificate.Details.PublicKey = nil
|
||||
rawCertNoKey, err := cs.Certificate.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
|
||||
}
|
||||
cs.RawCertificateNoKey = rawCertNoKey
|
||||
// put public key back
|
||||
cs.Certificate.Details.PublicKey = cs.PublicKey
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||
var pemPrivateKey []byte
|
||||
var err error
|
||||
|
||||
privPathOrPEM := c.GetString("pki.key", "")
|
||||
if privPathOrPEM == "" {
|
||||
return nil, errors.New("no pki.key path or PEM data provided")
|
||||
}
|
||||
|
||||
if strings.Contains(privPathOrPEM, "-----BEGIN") {
|
||||
pemPrivateKey = []byte(privPathOrPEM)
|
||||
privPathOrPEM = "<inline>"
|
||||
|
||||
} else {
|
||||
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
|
||||
}
|
||||
}
|
||||
|
||||
rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
|
||||
}
|
||||
|
||||
var rawCert []byte
|
||||
|
||||
pubPathOrPEM := c.GetString("pki.cert", "")
|
||||
if pubPathOrPEM == "" {
|
||||
return nil, errors.New("no pki.cert path or PEM data provided")
|
||||
}
|
||||
|
||||
if strings.Contains(pubPathOrPEM, "-----BEGIN") {
|
||||
rawCert = []byte(pubPathOrPEM)
|
||||
pubPathOrPEM = "<inline>"
|
||||
|
||||
} else {
|
||||
rawCert, err = os.ReadFile(pubPathOrPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
|
||||
}
|
||||
}
|
||||
|
||||
nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
|
||||
}
|
||||
|
||||
if nebulaCert.Expired(time.Now()) {
|
||||
return nil, fmt.Errorf("nebula certificate for this host is expired")
|
||||
}
|
||||
|
||||
if len(nebulaCert.Details.Ips) == 0 {
|
||||
return nil, fmt.Errorf("no IPs encoded in certificate")
|
||||
}
|
||||
|
||||
if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
|
||||
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
|
||||
}
|
||||
|
||||
return newCertState(nebulaCert, rawKey)
|
||||
}
|
||||
|
||||
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
|
||||
var rawCA []byte
|
||||
var err error
|
||||
|
||||
caPathOrPEM := c.GetString("pki.ca", "")
|
||||
if caPathOrPEM == "" {
|
||||
return nil, errors.New("no pki.ca path or PEM data provided")
|
||||
}
|
||||
|
||||
if strings.Contains(caPathOrPEM, "-----BEGIN") {
|
||||
rawCA = []byte(caPathOrPEM)
|
||||
|
||||
} else {
|
||||
rawCA, err = os.ReadFile(caPathOrPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
|
||||
}
|
||||
}
|
||||
|
||||
caPool, err := cert.NewCAPoolFromBytes(rawCA)
|
||||
if errors.Is(err, cert.ErrExpired) {
|
||||
var expired int
|
||||
for _, crt := range caPool.CAs {
|
||||
if crt.Expired(time.Now()) {
|
||||
expired++
|
||||
l.WithField("cert", crt).Warn("expired certificate present in CA pool")
|
||||
}
|
||||
}
|
||||
|
||||
if expired >= len(caPool.CAs) {
|
||||
return nil, errors.New("no valid CA certificates present")
|
||||
}
|
||||
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
|
||||
}
|
||||
|
||||
for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
|
||||
l.WithField("fingerprint", fp).Info("Blocklisting cert")
|
||||
caPool.BlocklistFingerprint(fp)
|
||||
}
|
||||
|
||||
return caPool, nil
|
||||
}
|
2
ssh.go
2
ssh.go
|
@ -754,7 +754,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
|
|||
return nil
|
||||
}
|
||||
|
||||
cert := ifce.certState.Load().certificate
|
||||
cert := ifce.pki.GetCertState().Certificate
|
||||
if len(a) > 0 {
|
||||
parsedIp := net.ParseIP(a[0])
|
||||
if parsedIp == nil {
|
||||
|
|
|
@ -12,18 +12,38 @@ type ContextualError struct {
|
|||
Context string
|
||||
}
|
||||
|
||||
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
|
||||
return ContextualError{Context: msg, Fields: fields, RealError: realError}
|
||||
func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError {
|
||||
return &ContextualError{Context: msg, Fields: fields, RealError: realError}
|
||||
}
|
||||
|
||||
func (ce ContextualError) Error() string {
|
||||
// ContextualizeIfNeeded is a helper function to turn an error into a ContextualError if it is not already one
|
||||
func ContextualizeIfNeeded(msg string, err error) error {
|
||||
switch err.(type) {
|
||||
case *ContextualError:
|
||||
return err
|
||||
default:
|
||||
return NewContextualError(msg, nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError
|
||||
func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) {
|
||||
switch v := err.(type) {
|
||||
case *ContextualError:
|
||||
v.Log(l)
|
||||
default:
|
||||
l.WithError(err).Error(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (ce *ContextualError) Error() string {
|
||||
if ce.RealError == nil {
|
||||
return ce.Context
|
||||
}
|
||||
return ce.RealError.Error()
|
||||
}
|
||||
|
||||
func (ce ContextualError) Unwrap() error {
|
||||
func (ce *ContextualError) Unwrap() error {
|
||||
if ce.RealError == nil {
|
||||
return errors.New(ce.Context)
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package util
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -67,3 +68,44 @@ func TestContextualError_Log(t *testing.T) {
|
|||
e.Log(l)
|
||||
assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
|
||||
}
|
||||
|
||||
func TestLogWithContextIfNeeded(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
DisableTimestamp: true,
|
||||
DisableColors: true,
|
||||
}
|
||||
|
||||
tl := NewTestLogWriter()
|
||||
l.Out = tl
|
||||
|
||||
// Test ignoring fallback context
|
||||
tl.Reset()
|
||||
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
|
||||
LogWithContextIfNeeded("This should get thrown away", e, l)
|
||||
assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
|
||||
|
||||
// Test using fallback context
|
||||
tl.Reset()
|
||||
err := fmt.Errorf("this is a normal error")
|
||||
LogWithContextIfNeeded("Fallback context woo", err, l)
|
||||
assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs)
|
||||
}
|
||||
|
||||
func TestContextualizeIfNeeded(t *testing.T) {
|
||||
// Test ignoring fallback context
|
||||
e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
|
||||
assert.Same(t, e, ContextualizeIfNeeded("should be ignored", e))
|
||||
|
||||
// Test using fallback context
|
||||
err := fmt.Errorf("this is a normal error")
|
||||
cErr := ContextualizeIfNeeded("Fallback context woo", err)
|
||||
|
||||
switch v := cErr.(type) {
|
||||
case *ContextualError:
|
||||
assert.Equal(t, err, v.RealError)
|
||||
default:
|
||||
t.Error("Error was not wrapped")
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue