mirror of https://github.com/slackhq/nebula.git
Fix ca* checks
This commit is contained in:
parent
8e6b72516b
commit
56657065e0
134
firewall.go
134
firewall.go
|
@ -83,19 +83,23 @@ func newFirewallTable() *FirewallTable {
|
|||
}
|
||||
}
|
||||
|
||||
type FirewallCA struct {
|
||||
Any *FirewallRule
|
||||
CANames map[string]*FirewallRule
|
||||
CAShas map[string]*FirewallRule
|
||||
}
|
||||
|
||||
type FirewallRule struct {
|
||||
// Any makes Hosts, Groups, and CIDR irrelevant. CAName and CASha still need to be checked
|
||||
Any bool
|
||||
Hosts map[string]struct{}
|
||||
Groups [][]string
|
||||
CIDR *CIDRTree
|
||||
CANames map[string]struct{}
|
||||
CAShas map[string]struct{}
|
||||
// Any makes Hosts, Groups, and CIDR irrelevant
|
||||
Any bool
|
||||
Hosts map[string]struct{}
|
||||
Groups [][]string
|
||||
CIDR *CIDRTree
|
||||
}
|
||||
|
||||
// Even though ports are uint16, int32 maps are faster for lookup
|
||||
// Plus we can use `-1` for fragment rules
|
||||
type firewallPort map[int32]*FirewallRule
|
||||
type firewallPort map[int32]*FirewallCA
|
||||
|
||||
type FirewallPacket struct {
|
||||
LocalIP uint32
|
||||
|
@ -182,9 +186,9 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
|
|||
|
||||
func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
|
||||
fw := NewFirewall(
|
||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Duration(time.Minute*12)),
|
||||
c.GetDuration("firewall.conntrack.udp_timeout", time.Duration(time.Minute*3)),
|
||||
c.GetDuration("firewall.conntrack.default_timeout", time.Duration(time.Minute*10)),
|
||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
||||
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
||||
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
||||
nc,
|
||||
//TODO: max_connections
|
||||
)
|
||||
|
@ -499,12 +503,9 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
|
|||
|
||||
for i := startPort; i <= endPort; i++ {
|
||||
if _, ok := fp[i]; !ok {
|
||||
fp[i] = &FirewallRule{
|
||||
Groups: make([][]string, 0),
|
||||
Hosts: make(map[string]struct{}),
|
||||
CIDR: NewCIDRTree(),
|
||||
CANames: make(map[string]struct{}),
|
||||
CAShas: make(map[string]struct{}),
|
||||
fp[i] = &FirewallCA{
|
||||
CANames: make(map[string]*FirewallRule),
|
||||
CAShas: make(map[string]*FirewallRule),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -539,15 +540,83 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
|
|||
return fp[fwPortAny].match(p, c, caPool)
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
|
||||
if caName != "" {
|
||||
fr.CANames[caName] = struct{}{}
|
||||
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
|
||||
// If there is an any rule then there is no need to establish specific ca rules
|
||||
if fc.Any != nil {
|
||||
return fc.Any.addRule(groups, host, ip)
|
||||
}
|
||||
|
||||
fr := func() *FirewallRule {
|
||||
return &FirewallRule{
|
||||
Hosts: make(map[string]struct{}),
|
||||
Groups: make([][]string, 0),
|
||||
CIDR: NewCIDRTree(),
|
||||
}
|
||||
}
|
||||
|
||||
any := false
|
||||
if caSha == "" && caName == "" {
|
||||
any = true
|
||||
}
|
||||
|
||||
if any {
|
||||
if fc.Any == nil {
|
||||
fc.Any = fr()
|
||||
}
|
||||
|
||||
// If it's any we need to wipe out any pre-existing rules to save on memory
|
||||
fc.CAShas = make(map[string]*FirewallRule)
|
||||
fc.CANames = make(map[string]*FirewallRule)
|
||||
return fc.Any.addRule(groups, host, ip)
|
||||
}
|
||||
|
||||
if caSha != "" {
|
||||
fr.CAShas[caSha] = struct{}{}
|
||||
if _, ok := fc.CAShas[caSha]; !ok {
|
||||
fc.CAShas[caSha] = fr()
|
||||
}
|
||||
err := fc.CAShas[caSha].addRule(groups, host, ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if caName != "" {
|
||||
if _, ok := fc.CANames[caName]; !ok {
|
||||
fc.CANames[caName] = fr()
|
||||
}
|
||||
err := fc.CANames[caName].addRule(groups, host, ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
if fc == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if fc.Any != nil {
|
||||
return fc.Any.match(p, c)
|
||||
}
|
||||
|
||||
if t, ok := fc.CAShas[c.Details.Issuer]; ok {
|
||||
if t.match(p, c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
s, err := caPool.GetCAForCert(c)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return fc.CANames[s.Details.Name].match(p, c)
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
|
||||
if fr.Any {
|
||||
return nil
|
||||
}
|
||||
|
@ -593,28 +662,11 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
|
|||
return false
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
|
||||
if fr == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// CASha and CAName always need to be checked
|
||||
if len(fr.CAShas) > 0 {
|
||||
if _, ok := fr.CAShas[c.Details.Issuer]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if len(fr.CANames) > 0 {
|
||||
s, err := caPool.GetCAForCert(c)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if _, ok := fr.CANames[s.Details.Name]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Shortcut path for if groups, hosts, or cidr contained an `any`
|
||||
if fr.Any {
|
||||
return true
|
||||
|
@ -773,7 +825,7 @@ func setTCPRTTTracking(c *conn, p []byte) {
|
|||
ihl := int(p[0]&0x0f) << 2
|
||||
|
||||
// Don't track FIN packets
|
||||
if uint8(p[ihl+13])&tcpFIN != 0 {
|
||||
if p[ihl+13]&tcpFIN != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -787,7 +839,7 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
|
|||
}
|
||||
|
||||
ihl := int(p[0]&0x0f) << 2
|
||||
if uint8(p[ihl+13])&tcpACK == 0 {
|
||||
if p[ihl+13]&tcpACK == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"testing"
|
||||
|
@ -61,37 +62,37 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||
assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
|
||||
// Make sure an empty rule creates structure but doesn't allow anything to flow
|
||||
//TODO: ideally an empty rule would return an error
|
||||
assert.False(t, fw.InRules.TCP[1].Any)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Groups)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Hosts)
|
||||
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left)
|
||||
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right)
|
||||
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value)
|
||||
assert.False(t, fw.InRules.TCP[1].Any.Any)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
||||
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
|
||||
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
|
||||
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
|
||||
assert.False(t, fw.InRules.UDP[1].Any)
|
||||
assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1")
|
||||
assert.Empty(t, fw.InRules.UDP[1].Hosts)
|
||||
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left)
|
||||
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right)
|
||||
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value)
|
||||
assert.False(t, fw.InRules.UDP[1].Any.Any)
|
||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
|
||||
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
||||
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
|
||||
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
|
||||
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
|
||||
assert.False(t, fw.InRules.ICMP[1].Any)
|
||||
assert.Empty(t, fw.InRules.ICMP[1].Groups)
|
||||
assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1")
|
||||
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left)
|
||||
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right)
|
||||
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value)
|
||||
assert.False(t, fw.InRules.ICMP[1].Any.Any)
|
||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
||||
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
|
||||
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
|
||||
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
|
||||
assert.False(t, fw.OutRules.AnyProto[1].Any)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Groups)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Hosts)
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP)))
|
||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
|
||||
|
@ -104,28 +105,30 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||
// Set any and clear fields
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
|
||||
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0])
|
||||
assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1")
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP)))
|
||||
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
||||
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
|
||||
|
||||
// run twice just to make sure
|
||||
//TODO: these ANY rules should clear the CA firewall portion
|
||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[0].Groups)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[0].Hosts)
|
||||
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left)
|
||||
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right)
|
||||
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value)
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
|
||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
|
||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
|
||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
|
||||
fmt.Printf("%+v\n", fw.OutRules.AnyProto[0])
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any)
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any)
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||
|
||||
// Test error conditions
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
|
@ -209,11 +212,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||
}
|
||||
|
||||
_, n, _ := net.ParseCIDR("172.1.1.1/32")
|
||||
ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
|
||||
ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
|
||||
ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
|
||||
ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
|
||||
ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
b.Run("fail on proto", func(b *testing.B) {
|
||||
|
@ -281,7 +284,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||
}
|
||||
})
|
||||
|
||||
ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
|
||||
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
|
||||
|
||||
b.Run("pass on ip with any port", func(b *testing.B) {
|
||||
ip := ip2int(net.IPv4(172, 1, 1, 1))
|
||||
|
|
Loading…
Reference in New Issue