Remove tcp rtt tracking from the firewall (#1114)

This commit is contained in:
Nate Brown 2024-04-11 21:44:22 -05:00 committed by GitHub
parent 7efa750aef
commit c1711bc9c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 175 deletions

View File

@ -2,7 +2,6 @@ package nebula
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
@ -22,17 +21,12 @@ import (
"github.com/slackhq/nebula/firewall"
)
const tcpACK = 0x10
const tcpFIN = 0x01
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
}
type conn struct {
Expires time.Time // Time when this conntrack entry will expire
Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
// record why the original connection passed the firewall, so we can re-validate
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
@ -66,8 +60,6 @@ type Firewall struct {
rulesVersion uint16
defaultLocalCIDRAny bool
trackTCPRTT bool
metricTCPRTT metrics.Histogram
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
@ -183,7 +175,6 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
hasSubnets: len(c.Details.Subnets) > 0,
l: l,
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
incomingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
@ -422,9 +413,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming, h, caPool, localCache) {
if f.inConns(fp, h, caPool, localCache) {
return nil
}
@ -462,7 +453,7 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
}
// We always want to conntrack since it is a faster operation
f.addConn(packet, fp, incoming)
f.addConn(fp, incoming)
return nil
}
@ -491,7 +482,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
@ -551,11 +542,6 @@ func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *
switch fp.Protocol {
case firewall.ProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout)
if incoming {
f.checkTCPRTT(c, packet)
} else {
setTCPRTTTracking(c, packet)
}
case firewall.ProtoUDP:
c.Expires = time.Now().Add(f.UDPTimeout)
default:
@ -571,16 +557,13 @@ func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *
return true
}
func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
var timeout time.Duration
c := &conn{}
switch fp.Protocol {
case firewall.ProtoTCP:
timeout = f.TCPTimeout
if !incoming {
setTCPRTTTracking(c, packet)
}
case firewall.ProtoUDP:
timeout = f.UDPTimeout
default:
@ -1017,42 +1000,3 @@ func parsePort(s string) (startPort, endPort int32, err error) {
return
}
// TODO: write tests for these
func setTCPRTTTracking(c *conn, p []byte) {
if c.Seq != 0 {
return
}
ihl := int(p[0]&0x0f) << 2
// Don't track FIN packets
if p[ihl+13]&tcpFIN != 0 {
return
}
c.Seq = binary.BigEndian.Uint32(p[ihl+4 : ihl+8])
c.Sent = time.Now()
}
func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
if c.Seq == 0 {
return false
}
ihl := int(p[0]&0x0f) << 2
if p[ihl+13]&tcpACK == 0 {
return false
}
// Deal with wrap around, signed int cuts the ack window in half
// 0 is a bad ack, no data acknowledged
// positive number is a bad ack, ack is over half the window away
if int32(c.Seq-binary.BigEndian.Uint32(p[ihl+8:ihl+12])) >= 0 {
return false
}
f.metricTCPRTT.Update(time.Since(c.Sent).Nanoseconds())
c.Seq = 0
return true
}

View File

@ -2,14 +2,12 @@ package nebula
import (
"bytes"
"encoding/binary"
"errors"
"math"
"net"
"testing"
"time"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
@ -163,44 +161,44 @@ func TestFirewall_Drop(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteIP
p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@ -412,10 +410,10 @@ func TestFirewall_Drop2(t *testing.T) {
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)
assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_Drop3(t *testing.T) {
@ -495,13 +493,13 @@ func TestFirewall_Drop3(t *testing.T) {
cp := cert.NewCAPool()
// c1 should pass because host match
assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))
assert.NoError(t, fw.Drop(p, true, &h1, cp, nil))
// c2 should pass because ca sha match
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))
assert.NoError(t, fw.Drop(p, true, &h2, cp, nil))
// c3 should fail because no match
resetConntrack(fw)
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_DropConntrackReload(t *testing.T) {
@ -545,12 +543,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
@ -559,7 +557,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
@ -568,7 +566,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
}
func BenchmarkLookup(b *testing.B) {
@ -830,97 +828,6 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
}
func TestTCPRTTTracking(t *testing.T) {
b := make([]byte, 200)
// Max ip IHL (60 bytes) and tcp IHL (60 bytes)
b[0] = 15
b[60+12] = 15 << 4
f := Firewall{
metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),
}
// Set SEQ to 1
binary.BigEndian.PutUint32(b[60+4:60+8], 1)
c := &conn{}
setTCPRTTTracking(c, b)
assert.Equal(t, uint32(1), c.Seq)
// Bad ack - no ack flag
binary.BigEndian.PutUint32(b[60+8:60+12], 80)
assert.False(t, f.checkTCPRTT(c, b))
// Bad ack, number is too low
binary.BigEndian.PutUint32(b[60+8:60+12], 0)
b[60+13] = uint8(0x10)
assert.False(t, f.checkTCPRTT(c, b))
// Good ack
binary.BigEndian.PutUint32(b[60+8:60+12], 80)
assert.True(t, f.checkTCPRTT(c, b))
assert.Equal(t, uint32(0), c.Seq)
// Set SEQ to 1
binary.BigEndian.PutUint32(b[60+4:60+8], 1)
c = &conn{}
setTCPRTTTracking(c, b)
assert.Equal(t, uint32(1), c.Seq)
// Good acks
binary.BigEndian.PutUint32(b[60+8:60+12], 81)
assert.True(t, f.checkTCPRTT(c, b))
assert.Equal(t, uint32(0), c.Seq)
// Set SEQ to max uint32 - 20
binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)
c = &conn{}
setTCPRTTTracking(c, b)
assert.Equal(t, ^uint32(0)-20, c.Seq)
// Good acks
binary.BigEndian.PutUint32(b[60+8:60+12], 81)
assert.True(t, f.checkTCPRTT(c, b))
assert.Equal(t, uint32(0), c.Seq)
// Set SEQ to max uint32 / 2
binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)
c = &conn{}
setTCPRTTTracking(c, b)
assert.Equal(t, ^uint32(0)/2, c.Seq)
// Below
binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)
assert.False(t, f.checkTCPRTT(c, b))
assert.Equal(t, ^uint32(0)/2, c.Seq)
// Halfway below
binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))
assert.False(t, f.checkTCPRTT(c, b))
assert.Equal(t, ^uint32(0)/2, c.Seq)
// Halfway above is ok
binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))
assert.True(t, f.checkTCPRTT(c, b))
assert.Equal(t, uint32(0), c.Seq)
// Set SEQ to max uint32
binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))
c = &conn{}
setTCPRTTTracking(c, b)
assert.Equal(t, ^uint32(0), c.Seq)
// Halfway + 1 above
binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)
assert.False(t, f.checkTCPRTT(c, b))
assert.Equal(t, ^uint32(0), c.Seq)
// Halfway above
binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)
assert.True(t, f.checkTCPRTT(c, b))
assert.Equal(t, uint32(0), c.Seq)
}
func TestFirewall_convertRule(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}

View File

@ -62,7 +62,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q)
@ -142,7 +142,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.pki.GetCAPool(), nil)
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).

View File

@ -404,7 +404,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in