mirror of https://github.com/slackhq/nebula.git
Teardown tunnel automatically if peer's certificate expired (#370)
This commit is contained in:
parent
e8b08e49e6
commit
32e2619323
|
@ -166,7 +166,23 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
|||
// Check for traffic coming back in from this host.
|
||||
traf := n.CheckIn(vpnIP)
|
||||
|
||||
// If we saw incoming packets from this ip, just return
|
||||
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
||||
if err != nil {
|
||||
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||
|
||||
if !n.intf.disconnectInvalid {
|
||||
n.ClearIP(vpnIP)
|
||||
n.ClearPendingDeletion(vpnIP)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
|
||||
continue
|
||||
}
|
||||
|
||||
// If we saw an incoming packets from this ip and peer's certificate is not
|
||||
// expired, just ignore.
|
||||
if traf {
|
||||
if n.l.Level >= logrus.DebugLevel {
|
||||
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
||||
|
@ -178,15 +194,6 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
|||
continue
|
||||
}
|
||||
|
||||
// If we didn't we may need to probe or destroy the conn
|
||||
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
||||
if err != nil {
|
||||
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||
n.ClearIP(vpnIP)
|
||||
n.ClearPendingDeletion(vpnIP)
|
||||
continue
|
||||
}
|
||||
|
||||
hostinfo.logger(n.l).
|
||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
||||
Debug("Tunnel status")
|
||||
|
@ -213,22 +220,31 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
|||
|
||||
vpnIP := ep.(uint32)
|
||||
|
||||
// If we saw incoming packets from this ip, just return
|
||||
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
||||
if err != nil {
|
||||
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||
|
||||
if !n.intf.disconnectInvalid {
|
||||
n.ClearIP(vpnIP)
|
||||
n.ClearPendingDeletion(vpnIP)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
|
||||
continue
|
||||
}
|
||||
|
||||
// If we saw an incoming packets from this ip and peer's certificate is not
|
||||
// expired, just ignore.
|
||||
traf := n.CheckIn(vpnIP)
|
||||
if traf {
|
||||
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
||||
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
|
||||
Debug("Tunnel status")
|
||||
n.ClearIP(vpnIP)
|
||||
n.ClearPendingDeletion(vpnIP)
|
||||
continue
|
||||
}
|
||||
|
||||
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
||||
if err != nil {
|
||||
n.ClearIP(vpnIP)
|
||||
n.ClearPendingDeletion(vpnIP)
|
||||
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -256,3 +272,34 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
|
||||
func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32, hostinfo *HostInfo) bool {
|
||||
if !n.intf.disconnectInvalid {
|
||||
return false
|
||||
}
|
||||
|
||||
remoteCert := hostinfo.GetCert()
|
||||
if remoteCert == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
valid, err := remoteCert.Verify(now, n.intf.caPool)
|
||||
if valid {
|
||||
return false
|
||||
}
|
||||
|
||||
fingerprint, _ := remoteCert.Sha256Sum()
|
||||
n.l.WithField("vpnIp", IntIp(vpnIP)).WithError(err).
|
||||
WithField("certName", remoteCert.Details.Name).
|
||||
WithField("fingerprint", fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
|
||||
// Inform the remote and close the tunnel locally
|
||||
n.intf.sendCloseTunnel(hostinfo)
|
||||
n.intf.closeTunnel(hostinfo, false)
|
||||
|
||||
n.ClearIP(vpnIP)
|
||||
n.ClearPendingDeletion(vpnIP)
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -148,3 +150,96 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
|
||||
|
||||
}
|
||||
|
||||
// Check if we can disconnect the peer.
|
||||
// Validate if the peer's certificate is invalid (expired, etc.)
|
||||
// Disconnect only if disconnectInvalid: true is set.
|
||||
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
now := time.Now()
|
||||
l := NewTestLogger()
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(172, 1, 1, 2),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
}
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||
|
||||
// Generate keys for CA and peer's cert.
|
||||
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
|
||||
caCert := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "ca",
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(1 * time.Hour),
|
||||
IsCA: true,
|
||||
PublicKey: pubCA,
|
||||
},
|
||||
}
|
||||
caCert.Sign(privCA)
|
||||
ncp := &cert.NebulaCAPool{
|
||||
CAs: cert.NewCAPool().CAs,
|
||||
}
|
||||
ncp.CAs["ca"] = &caCert
|
||||
|
||||
pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
|
||||
peerCert := cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "host",
|
||||
Ips: []*net.IPNet{&ipNet},
|
||||
Subnets: []*net.IPNet{},
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(60 * time.Second),
|
||||
PublicKey: pubCrt,
|
||||
IsCA: false,
|
||||
Issuer: "ca",
|
||||
},
|
||||
}
|
||||
peerCert.Sign(privCA)
|
||||
|
||||
cs := &CertState{
|
||||
rawCertificate: []byte{},
|
||||
privateKey: []byte{},
|
||||
certificate: &cert.NebulaCertificate{},
|
||||
rawCertificateNoKey: []byte{},
|
||||
}
|
||||
|
||||
lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||
ifce := &Interface{
|
||||
hostMap: hostMap,
|
||||
inside: &Tun{},
|
||||
outside: &udpConn{},
|
||||
certState: cs,
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
||||
l: l,
|
||||
disconnectInvalid: true,
|
||||
caPool: ncp,
|
||||
}
|
||||
|
||||
// Create manager
|
||||
nc := newConnectionManager(l, ifce, 5, 10)
|
||||
ifce.connectionManager = nc
|
||||
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
certState: cs,
|
||||
peerCert: &peerCert,
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
|
||||
// Move ahead 45s.
|
||||
// Check if to disconnect with invalid certificate.
|
||||
// Should be alive.
|
||||
nextTick := now.Add(45 * time.Second)
|
||||
destroyed := nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
|
||||
assert.False(t, destroyed)
|
||||
|
||||
// Move ahead 61s.
|
||||
// Check if to disconnect with invalid certificate.
|
||||
// Should be disconnected.
|
||||
nextTick = now.Add(61 * time.Second)
|
||||
destroyed = nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
|
||||
assert.True(t, destroyed)
|
||||
}
|
||||
|
|
|
@ -7,9 +7,11 @@ pki:
|
|||
ca: /etc/nebula/ca.crt
|
||||
cert: /etc/nebula/host.crt
|
||||
key: /etc/nebula/host.key
|
||||
#blocklist is a list of certificate fingerprints that we will refuse to talk to
|
||||
# blocklist is a list of certificate fingerprints that we will refuse to talk to
|
||||
#blocklist:
|
||||
# - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
|
||||
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
|
||||
#disconnect_invalid: false
|
||||
|
||||
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
||||
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
||||
|
|
|
@ -43,6 +43,7 @@ type InterfaceConfig struct {
|
|||
MessageMetrics *MessageMetrics
|
||||
version string
|
||||
caPool *cert.NebulaCAPool
|
||||
disconnectInvalid bool
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
l *logrus.Logger
|
||||
|
@ -67,6 +68,7 @@ type Interface struct {
|
|||
udpBatchSize int
|
||||
routines int
|
||||
caPool *cert.NebulaCAPool
|
||||
disconnectInvalid bool
|
||||
|
||||
// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
|
||||
rebindCount int8
|
||||
|
@ -118,6 +120,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
|||
writers: make([]*udpConn, c.routines),
|
||||
readers: make([]io.ReadWriteCloser, c.routines),
|
||||
caPool: c.caPool,
|
||||
disconnectInvalid: c.disconnectInvalid,
|
||||
myVpnIp: ip2int(c.certState.certificate.Details.Ips[0].IP),
|
||||
|
||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||
|
|
1
main.go
1
main.go
|
@ -371,6 +371,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
MessageMetrics: messageMetrics,
|
||||
version: buildVersion,
|
||||
caPool: caPool,
|
||||
disconnectInvalid: config.GetBool("pki.disconnect_invalid", false),
|
||||
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
l: l,
|
||||
|
|
Loading…
Reference in New Issue