From 5181cb0474933514c54c52ed41e1ba7e66e50dd4 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 2 Nov 2023 17:05:08 -0500 Subject: [PATCH] Use generics for CIDRTrees to avoid casting issues (#1004) --- allow_list.go | 43 ++++--------- allow_list_test.go | 2 +- calculated_remote.go | 4 +- cidr/tree4.go | 62 ++++++++++-------- cidr/tree4_test.go | 118 ++++++++++++++++++++-------------- cidr/tree6.go | 44 +++++++------ cidr/tree6_test.go | 73 +++++++++++++-------- firewall.go | 36 +++++++---- firewall_test.go | 12 ++-- hostmap.go | 4 +- lighthouse.go | 9 ++- overlay/route.go | 4 +- overlay/route_test.go | 18 +++--- overlay/tun_darwin.go | 8 +-- overlay/tun_freebsd.go | 10 +-- overlay/tun_linux.go | 14 ++-- overlay/tun_netbsd.go | 10 +-- overlay/tun_openbsd.go | 10 +-- overlay/tun_tester.go | 10 +-- overlay/tun_water_windows.go | 10 +-- overlay/tun_wintun_windows.go | 10 +-- 21 files changed, 264 insertions(+), 247 deletions(-) diff --git a/allow_list.go b/allow_list.go index 0e44a12..9186b2f 100644 --- a/allow_list.go +++ b/allow_list.go @@ -12,7 +12,7 @@ import ( type AllowList struct { // The values of this cidrTree are `bool`, signifying allow/deny - cidrTree *cidr.Tree6 + cidrTree *cidr.Tree6[bool] } type RemoteAllowList struct { @@ -20,7 +20,7 @@ type RemoteAllowList struct { // Inside Range Specific, keys of this tree are inside CIDRs and values // are *AllowList - insideAllowLists *cidr.Tree6 + insideAllowLists *cidr.Tree6[*AllowList] } type LocalAllowList struct { @@ -88,7 +88,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } - tree := cidr.NewTree6() + tree := cidr.NewTree6[bool]() // Keep track of the rules we have added for both ipv4 and ipv6 type allowListRules struct { @@ -218,13 +218,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error return nameRules, nil } -func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) { +func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) { value := c.Get(k) if value == nil { return nil, nil } - remoteAllowRanges := cidr.NewTree6() + remoteAllowRanges := cidr.NewTree6[*AllowList]() rawMap, ok := value.(map[interface{}]interface{}) if !ok { @@ -257,13 +257,8 @@ func (al *AllowList) Allow(ip net.IP) bool { return true } - result := al.cidrTree.MostSpecificContains(ip) - switch v := result.(type) { - case bool: - return v - default: - panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result)) - } + _, result := al.cidrTree.MostSpecificContains(ip) + return result } func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { @@ -271,13 +266,8 @@ func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { return true } - result := al.cidrTree.MostSpecificContainsIpV4(ip) - switch v := result.(type) { - case bool: - return v - default: - panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result)) - } + _, result := al.cidrTree.MostSpecificContainsIpV4(ip) + return result } func (al *AllowList) AllowIpV6(hi, lo uint64) bool { @@ -285,13 +275,8 @@ func (al *AllowList) AllowIpV6(hi, lo uint64) bool { return true } - result := al.cidrTree.MostSpecificContainsIpV6(hi, lo) - switch v := result.(type) { - case bool: - return v - default: - panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result)) - } + _, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo) + return result } func (al *LocalAllowList) Allow(ip net.IP) bool { @@ -352,9 +337,9 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool { func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList { if al.insideAllowLists != nil { - inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) - if inside != nil { - return inside.(*AllowList) + ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) + if ok { + return inside } } return nil diff --git a/allow_list_test.go b/allow_list_test.go index 991b8a3..334cb60 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -100,7 +100,7 @@ func TestNewAllowListFromConfig(t *testing.T) { func TestAllowList_Allow(t *testing.T) { assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1"))) - tree := cidr.NewTree6() + tree := cidr.NewTree6[bool]() tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true) tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false) tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true) diff --git a/calculated_remote.go b/calculated_remote.go index 910f757..38f5bea 100644 --- a/calculated_remote.go +++ b/calculated_remote.go @@ -51,13 +51,13 @@ func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort { return &Ip4AndPort{Ip: uint32(masked), Port: c.port} } -func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4, error) { +func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) { value := c.Get(k) if value == nil { return nil, nil } - calculatedRemotes := cidr.NewTree4() + calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]() rawMap, ok := value.(map[any]any) if !ok { diff --git a/cidr/tree4.go b/cidr/tree4.go index 0839c90..fd4b358 100644 --- a/cidr/tree4.go +++ b/cidr/tree4.go @@ -6,35 +6,36 @@ import ( "github.com/slackhq/nebula/iputil" ) -type Node struct { - left *Node - right *Node - parent *Node - value interface{} +type Node[T any] struct { + left *Node[T] + right *Node[T] + parent *Node[T] + hasValue bool + value T } -type entry struct { +type entry[T any] struct { CIDR *net.IPNet - Value *interface{} + Value T } -type Tree4 struct { - root *Node - list []entry +type Tree4[T any] struct { + root *Node[T] + list []entry[T] } const ( startbit = iputil.VpnIp(0x80000000) ) -func NewTree4() *Tree4 { - tree := new(Tree4) - tree.root = &Node{} - tree.list = []entry{} +func NewTree4[T any]() *Tree4[T] { + tree := new(Tree4[T]) + tree.root = &Node[T]{} + tree.list = []entry[T]{} return tree } -func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { +func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) { bit := startbit node := tree.root next := tree.root @@ -68,14 +69,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { } } - tree.list = append(tree.list, entry{CIDR: cidr, Value: &val}) + tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) node.value = val + node.hasValue = true return } // Build up the rest of the tree we don't already have for bit&mask != 0 { - next = &Node{} + next = &Node[T]{} next.parent = node if ip&bit != 0 { @@ -90,17 +92,18 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { // Final node marks our cidr, set the value node.value = val - tree.list = append(tree.list, entry{CIDR: cidr, Value: &val}) + node.hasValue = true + tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) } // Contains finds the first match, which may be the least specific -func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { +func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) { bit := startbit node := tree.root for node != nil { - if node.value != nil { - return node.value + if node.hasValue { + return true, node.value } if ip&bit != 0 { @@ -113,17 +116,18 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { } - return value + return false, value } // MostSpecificContains finds the most specific match -func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { +func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) { bit := startbit node := tree.root for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if ip&bit != 0 { @@ -135,11 +139,12 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { bit >>= 1 } - return value + return ok, value } // Match finds the most specific match -func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { +// TODO this is exact match +func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) { bit := startbit node := tree.root lastNode := node @@ -157,11 +162,12 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { if bit == 0 && lastNode != nil { value = lastNode.value + ok = true } - return value + return ok, value } // List will return all CIDRs and their current values. Do not modify the contents! -func (tree *Tree4) List() []entry { +func (tree *Tree4[T]) List() []entry[T] { return tree.list } diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go index dce8d54..acd403e 100644 --- a/cidr/tree4_test.go +++ b/cidr/tree4_test.go @@ -9,7 +9,7 @@ import ( ) func TestCIDRTree_List(t *testing.T) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.0.0.0/16"), "1") tree.AddCIDR(Parse("1.0.0.0/8"), "2") tree.AddCIDR(Parse("1.0.0.0/16"), "3") @@ -17,13 +17,13 @@ func TestCIDRTree_List(t *testing.T) { list := tree.List() assert.Len(t, list, 2) assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String()) - assert.Equal(t, "2", *list[0].Value) + assert.Equal(t, "2", list[0].Value) assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String()) - assert.Equal(t, "4", *list[1].Value) + assert.Equal(t, "4", list[1].Value) } func TestCIDRTree_Contains(t *testing.T) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.0.0.0/8"), "1") tree.AddCIDR(Parse("2.1.0.0/16"), "2") tree.AddCIDR(Parse("3.1.1.0/24"), "3") @@ -33,35 +33,43 @@ func TestCIDRTree_Contains(t *testing.T) { tree.AddCIDR(Parse("254.0.0.0/4"), "5") tests := []struct { + Found bool Result interface{} IP string }{ - {"1", "1.0.0.0"}, - {"1", "1.255.255.255"}, - {"2", "2.1.0.0"}, - {"2", "2.1.255.255"}, - {"3", "3.1.1.0"}, - {"3", "3.1.1.255"}, - {"4a", "4.1.1.255"}, - {"4a", "4.1.1.1"}, - {"5", "240.0.0.0"}, - {"5", "255.255.255.255"}, - {nil, "239.0.0.0"}, - {nil, "4.1.2.2"}, + {true, "1", "1.0.0.0"}, + {true, "1", "1.255.255.255"}, + {true, "2", "2.1.0.0"}, + {true, "2", "2.1.255.255"}, + {true, "3", "3.1.1.0"}, + {true, "3", "3.1.1.255"}, + {true, "4a", "4.1.1.255"}, + {true, "4a", "4.1.1.1"}, + {true, "5", "240.0.0.0"}, + {true, "5", "255.255.255.255"}, + {false, "", "239.0.0.0"}, + {false, "", "4.1.2.2"}, } for _, tt := range tests { - assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - tree = NewTree4() + tree = NewTree4[string]() tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) + ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) } func TestCIDRTree_MostSpecificContains(t *testing.T) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.0.0.0/8"), "1") tree.AddCIDR(Parse("2.1.0.0/16"), "2") tree.AddCIDR(Parse("3.1.1.0/24"), "3") @@ -71,59 +79,75 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) { tree.AddCIDR(Parse("254.0.0.0/4"), "5") tests := []struct { + Found bool Result interface{} IP string }{ - {"1", "1.0.0.0"}, - {"1", "1.255.255.255"}, - {"2", "2.1.0.0"}, - {"2", "2.1.255.255"}, - {"3", "3.1.1.0"}, - {"3", "3.1.1.255"}, - {"4a", "4.1.1.255"}, - {"4b", "4.1.1.2"}, - {"4c", "4.1.1.1"}, - {"5", "240.0.0.0"}, - {"5", "255.255.255.255"}, - {nil, "239.0.0.0"}, - {nil, "4.1.2.2"}, + {true, "1", "1.0.0.0"}, + {true, "1", "1.255.255.255"}, + {true, "2", "2.1.0.0"}, + {true, "2", "2.1.255.255"}, + {true, "3", "3.1.1.0"}, + {true, "3", "3.1.1.255"}, + {true, "4a", "4.1.1.255"}, + {true, "4b", "4.1.1.2"}, + {true, "4c", "4.1.1.1"}, + {true, "5", "240.0.0.0"}, + {true, "5", "255.255.255.255"}, + {false, "", "239.0.0.0"}, + {false, "", "4.1.2.2"}, } for _, tt := range tests { - assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - tree = NewTree4() + tree = NewTree4[string]() tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) + ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) } func TestCIDRTree_Match(t *testing.T) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("4.1.1.0/32"), "1a") tree.AddCIDR(Parse("4.1.1.1/32"), "1b") tests := []struct { + Found bool Result interface{} IP string }{ - {"1a", "4.1.1.0"}, - {"1b", "4.1.1.1"}, + {true, "1a", "4.1.1.0"}, + {true, "1b", "4.1.1.1"}, } for _, tt := range tests { - assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))) + ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - tree = NewTree4() + tree = NewTree4[string]() tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))) - assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))) + ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) + assert.True(t, ok) + assert.Equal(t, "cool", r) } func BenchmarkCIDRTree_Contains(b *testing.B) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.1.0.0/16"), "1") tree.AddCIDR(Parse("1.2.1.1/32"), "1") tree.AddCIDR(Parse("192.2.1.1/32"), "1") @@ -145,7 +169,7 @@ func BenchmarkCIDRTree_Contains(b *testing.B) { } func BenchmarkCIDRTree_Match(b *testing.B) { - tree := NewTree4() + tree := NewTree4[string]() tree.AddCIDR(Parse("1.1.0.0/16"), "1") tree.AddCIDR(Parse("1.2.1.1/32"), "1") tree.AddCIDR(Parse("192.2.1.1/32"), "1") diff --git a/cidr/tree6.go b/cidr/tree6.go index d13c93d..3f2cd2a 100644 --- a/cidr/tree6.go +++ b/cidr/tree6.go @@ -8,20 +8,20 @@ import ( const startbit6 = uint64(1 << 63) -type Tree6 struct { - root4 *Node - root6 *Node +type Tree6[T any] struct { + root4 *Node[T] + root6 *Node[T] } -func NewTree6() *Tree6 { - tree := new(Tree6) - tree.root4 = &Node{} - tree.root6 = &Node{} +func NewTree6[T any]() *Tree6[T] { + tree := new(Tree6[T]) + tree.root4 = &Node[T]{} + tree.root6 = &Node[T]{} return tree } -func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) { - var node, next *Node +func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) { + var node, next *Node[T] cidrIP, ipv4 := isIPV4(cidr.IP) if ipv4 { @@ -56,7 +56,7 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) { // Build up the rest of the tree we don't already have for bit&mask != 0 { - next = &Node{} + next = &Node[T]{} next.parent = node if ip&bit != 0 { @@ -72,11 +72,12 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) { // Final node marks our cidr, set the value node.value = val + node.hasValue = true } // Finds the most specific match -func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) { - var node *Node +func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) { + var node *Node[T] wholeIP, ipv4 := isIPV4(ip) if ipv4 { @@ -90,8 +91,9 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) { bit := startbit for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if bit == 0 { @@ -108,16 +110,17 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) { } } - return value + return ok, value } -func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) { +func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) { bit := startbit node := tree.root4 for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if ip&bit != 0 { @@ -129,10 +132,10 @@ func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) bit >>= 1 } - return value + return ok, value } -func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { +func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) { ip := hi node := tree.root6 @@ -140,8 +143,9 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { bit := startbit6 for node != nil { - if node.value != nil { + if node.hasValue { value = node.value + ok = true } if bit == 0 { @@ -160,7 +164,7 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) { ip = lo } - return value + return ok, value } func isIPV4(ip net.IP) (net.IP, bool) { diff --git a/cidr/tree6_test.go b/cidr/tree6_test.go index b6dc4c2..eb159ec 100644 --- a/cidr/tree6_test.go +++ b/cidr/tree6_test.go @@ -9,7 +9,7 @@ import ( ) func TestCIDR6Tree_MostSpecificContains(t *testing.T) { - tree := NewTree6() + tree := NewTree6[string]() tree.AddCIDR(Parse("1.0.0.0/8"), "1") tree.AddCIDR(Parse("2.1.0.0/16"), "2") tree.AddCIDR(Parse("3.1.1.0/24"), "3") @@ -22,53 +22,68 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) { tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") tests := []struct { + Found bool Result interface{} IP string }{ - {"1", "1.0.0.0"}, - {"1", "1.255.255.255"}, - {"2", "2.1.0.0"}, - {"2", "2.1.255.255"}, - {"3", "3.1.1.0"}, - {"3", "3.1.1.255"}, - {"4a", "4.1.1.255"}, - {"4b", "4.1.1.2"}, - {"4c", "4.1.1.1"}, - {"5", "240.0.0.0"}, - {"5", "255.255.255.255"}, - {"6a", "1:2:0:4:1:1:1:1"}, - {"6b", "1:2:0:4:5:1:1:1"}, - {"6c", "1:2:0:4:5:0:0:0"}, - {nil, "239.0.0.0"}, - {nil, "4.1.2.2"}, + {true, "1", "1.0.0.0"}, + {true, "1", "1.255.255.255"}, + {true, "2", "2.1.0.0"}, + {true, "2", "2.1.255.255"}, + {true, "3", "3.1.1.0"}, + {true, "3", "3.1.1.255"}, + {true, "4a", "4.1.1.255"}, + {true, "4b", "4.1.1.2"}, + {true, "4c", "4.1.1.1"}, + {true, "5", "240.0.0.0"}, + {true, "5", "255.255.255.255"}, + {true, "6a", "1:2:0:4:1:1:1:1"}, + {true, "6b", "1:2:0:4:5:1:1:1"}, + {true, "6c", "1:2:0:4:5:0:0:0"}, + {false, "", "239.0.0.0"}, + {false, "", "4.1.2.2"}, } for _, tt := range tests { - assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP))) + ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP)) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } - tree = NewTree6() + tree = NewTree6[string]() tree.AddCIDR(Parse("1.1.1.1/0"), "cool") tree.AddCIDR(Parse("::/0"), "cool6") - assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0"))) - assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255"))) - assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::"))) - assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))) + ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0")) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255")) + assert.True(t, ok) + assert.Equal(t, "cool", r) + + ok, r = tree.MostSpecificContains(net.ParseIP("::")) + assert.True(t, ok) + assert.Equal(t, "cool6", r) + + ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")) + assert.True(t, ok) + assert.Equal(t, "cool6", r) } func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { - tree := NewTree6() + tree := NewTree6[string]() tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") tests := []struct { + Found bool Result interface{} IP string }{ - {"6a", "1:2:0:4:1:1:1:1"}, - {"6b", "1:2:0:4:5:1:1:1"}, - {"6c", "1:2:0:4:5:0:0:0"}, + {true, "6a", "1:2:0:4:1:1:1:1"}, + {true, "6b", "1:2:0:4:5:1:1:1"}, + {true, "6c", "1:2:0:4:5:0:0:0"}, } for _, tt := range tests { @@ -76,6 +91,8 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { hi := binary.BigEndian.Uint64(ip[:8]) lo := binary.BigEndian.Uint64(ip[8:]) - assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo)) + ok, r := tree.MostSpecificContainsIpV6(hi, lo) + assert.Equal(t, tt.Found, ok) + assert.Equal(t, tt.Result, r) } } diff --git a/firewall.go b/firewall.go index 93d940d..bdda18c 100644 --- a/firewall.go +++ b/firewall.go @@ -57,7 +57,7 @@ type Firewall struct { DefaultTimeout time.Duration //linux: 600s // Used to ensure we don't emit local packets for ips we don't own - localIps *cidr.Tree4 + localIps *cidr.Tree4[struct{}] rules string rulesVersion uint16 @@ -110,8 +110,8 @@ type FirewallRule struct { Any bool Hosts map[string]struct{} Groups [][]string - CIDR *cidr.Tree4 - LocalCIDR *cidr.Tree4 + CIDR *cidr.Tree4[struct{}] + LocalCIDR *cidr.Tree4[struct{}] } // Even though ports are uint16, int32 maps are faster for lookup @@ -137,7 +137,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D max = defaultTimeout } - localIps := cidr.NewTree4() + localIps := cidr.NewTree4[struct{}]() for _, ip := range c.Details.Ips { localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) } @@ -391,7 +391,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos // Make sure remote address matches nebula certificate if remoteCidr := h.remoteCidr; remoteCidr != nil { - if remoteCidr.Contains(fp.RemoteIP) == nil { + ok, _ := remoteCidr.Contains(fp.RemoteIP) + if !ok { f.metrics(incoming).droppedRemoteIP.Inc(1) return ErrInvalidRemoteIP } @@ -404,7 +405,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos } // Make sure we are supposed to be handling this local ip address - if f.localIps.Contains(fp.LocalIP) == nil { + ok, _ := f.localIps.Contains(fp.LocalIP) + if !ok { f.metrics(incoming).droppedLocalIP.Inc(1) return ErrInvalidLocalIP } @@ -657,8 +659,8 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN return &FirewallRule{ Hosts: make(map[string]struct{}), Groups: make([][]string, 0), - CIDR: cidr.NewTree4(), - LocalCIDR: cidr.NewTree4(), + CIDR: cidr.NewTree4[struct{}](), + LocalCIDR: cidr.NewTree4[struct{}](), } } @@ -726,8 +728,8 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc // If it's any we need to wipe out any pre-existing rules to save on memory fr.Groups = make([][]string, 0) fr.Hosts = make(map[string]struct{}) - fr.CIDR = cidr.NewTree4() - fr.LocalCIDR = cidr.NewTree4() + fr.CIDR = cidr.NewTree4[struct{}]() + fr.LocalCIDR = cidr.NewTree4[struct{}]() } else { if len(groups) > 0 { fr.Groups = append(fr.Groups, groups) @@ -809,12 +811,18 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } } - if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil { - return true + if fr.CIDR != nil { + ok, _ := fr.CIDR.Contains(p.RemoteIP) + if ok { + return true + } } - if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil { - return true + if fr.LocalCIDR != nil { + ok, _ := fr.LocalCIDR.Contains(p.LocalIP) + if ok { + return true + } } // No host, group, or cidr matched, bye bye diff --git a/firewall_test.go b/firewall_test.go index 7ffa747..83da899 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -92,14 +92,16 @@ func TestFirewall_AddRule(t *testing.T) { assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) - assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) + ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)) + assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) - assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))) + ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)) + assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) @@ -114,8 +116,10 @@ func TestFirewall_AddRule(t *testing.T) { assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", "")) assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") - assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) - assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))) + ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)) + assert.True(t, ok) + ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)) + assert.True(t, ok) // run twice just to make sure //TODO: these ANY rules should clear the CA firewall portion diff --git a/hostmap.go b/hostmap.go index 4358632..df388cd 100644 --- a/hostmap.go +++ b/hostmap.go @@ -205,7 +205,7 @@ type HostInfo struct { localIndexId uint32 vpnIp iputil.VpnIp recvError atomic.Uint32 - remoteCidr *cidr.Tree4 + remoteCidr *cidr.Tree4[struct{}] relayState RelayState // HandshakePacket records the packets used to create this hostinfo @@ -633,7 +633,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { return } - remoteCidr := cidr.NewTree4() + remoteCidr := cidr.NewTree4[struct{}]() for _, ip := range c.Details.Ips { remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) } diff --git a/lighthouse.go b/lighthouse.go index 9b3b837..2193ad3 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -74,7 +74,7 @@ type LightHouse struct { // IP's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]iputil.VpnIp] - calculatedRemotes atomic.Pointer[cidr.Tree4] // Maps VpnIp to []*calculatedRemote + calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter @@ -166,7 +166,7 @@ func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { return *lh.relaysForMe.Load() } -func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4 { +func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] { return lh.calculatedRemotes.Load() } @@ -594,11 +594,10 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { if tree == nil { return false } - value := tree.MostSpecificContains(vpnIp) - if value == nil { + ok, calculatedRemotes := tree.MostSpecificContains(vpnIp) + if !ok { return false } - calculatedRemotes := value.([]*calculatedRemote) var calculated []*Ip4AndPort for _, cr := range calculatedRemotes { diff --git a/overlay/route.go b/overlay/route.go index 41c7a9c..793c8fd 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -21,8 +21,8 @@ type Route struct { Install bool } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) { - routeTree := cidr.NewTree4() +func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) { + routeTree := cidr.NewTree4[iputil.VpnIp]() for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) diff --git a/overlay/route_test.go b/overlay/route_test.go index f83b5c1..46fb87c 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -265,18 +265,16 @@ func Test_makeRouteTree(t *testing.T) { assert.NoError(t, err) ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2")) - r := routeTree.MostSpecificContains(ip) - assert.NotNil(t, r) - assert.IsType(t, iputil.VpnIp(0), r) - assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) + ok, r := routeTree.MostSpecificContains(ip) + assert.True(t, ok) + assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1")) - r = routeTree.MostSpecificContains(ip) - assert.NotNil(t, r) - assert.IsType(t, iputil.VpnIp(0), r) - assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) + ok, r = routeTree.MostSpecificContains(ip) + assert.True(t, ok) + assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1")) - r = routeTree.MostSpecificContains(ip) - assert.Nil(t, r) + ok, r = routeTree.MostSpecificContains(ip) + assert.False(t, ok) } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 428e38f..caec580 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -25,7 +25,7 @@ type tun struct { cidr *net.IPNet DefaultMTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata @@ -304,9 +304,9 @@ func (t *tun) Activate() error { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) + ok, r := t.routeTree.MostSpecificContains(ip) + if ok { + return r } return 0 diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 8a52954..338b8f6 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -48,7 +48,7 @@ type tun struct { cidr *net.IPNet MTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger io.ReadWriteCloser @@ -192,12 +192,8 @@ func (t *tun) Activate() error { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *tun) Cidr() *net.IPNet { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 8751a3f..a576bf3 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -30,7 +30,7 @@ type tun struct { TXQueueLen int Routes []Route - routeTree atomic.Pointer[cidr.Tree4] + routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeChan chan struct{} useSystemRoutes bool @@ -154,12 +154,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.Load().MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.Load().MostSpecificContains(ip) + return r } func (t *tun) Write(b []byte) (int, error) { @@ -380,7 +376,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - newTree := cidr.NewTree4() + newTree := cidr.NewTree4[iputil.VpnIp]() if r.Type == unix.RTM_NEWROUTE { for _, oldR := range t.routeTree.Load().List() { newTree.AddCIDR(oldR.CIDR, oldR.Value) @@ -392,7 +388,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } else { gw := iputil.Ip2VpnIp(r.Gw) for _, oldR := range t.routeTree.Load().List() { - if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw { + if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw { // This is the record to delete t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") continue diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 4d7f897..b1135fe 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -29,7 +29,7 @@ type tun struct { cidr *net.IPNet MTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger io.ReadWriteCloser @@ -134,12 +134,8 @@ func (t *tun) Activate() error { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *tun) Cidr() *net.IPNet { diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 709fb42..45c06dc 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -23,7 +23,7 @@ type tun struct { cidr *net.IPNet MTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger io.ReadWriteCloser @@ -115,12 +115,8 @@ func (t *tun) Activate() error { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *tun) Cidr() *net.IPNet { diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index a2a57e1..964315a 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -19,7 +19,7 @@ type TestTun struct { Device string cidr *net.IPNet Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] l *logrus.Logger closed atomic.Bool @@ -83,12 +83,8 @@ func (t *TestTun) Get(block bool) []byte { //********************************************************************************************************************// func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *TestTun) Activate() error { diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go index b1c28d6..e27cff2 100644 --- a/overlay/tun_water_windows.go +++ b/overlay/tun_water_windows.go @@ -18,7 +18,7 @@ type waterTun struct { cidr *net.IPNet MTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] *water.Interface } @@ -97,12 +97,8 @@ func (t *waterTun) Activate() error { } func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *waterTun) Cidr() *net.IPNet { diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index a406123..9647024 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -24,7 +24,7 @@ type winTun struct { prefix netip.Prefix MTU int Routes []Route - routeTree *cidr.Tree4 + routeTree *cidr.Tree4[iputil.VpnIp] tun *wintun.NativeTun } @@ -146,12 +146,8 @@ func (t *winTun) Activate() error { } func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) - if r != nil { - return r.(iputil.VpnIp) - } - - return 0 + _, r := t.routeTree.MostSpecificContains(ip) + return r } func (t *winTun) Cidr() *net.IPNet {