mirror of https://github.com/slackhq/nebula.git
Use generics for CIDRTrees to avoid casting issues (#1004)
This commit is contained in:
parent
a44e1b8b05
commit
5181cb0474
|
@ -12,7 +12,7 @@ import (
|
||||||
|
|
||||||
type AllowList struct {
|
type AllowList struct {
|
||||||
// The values of this cidrTree are `bool`, signifying allow/deny
|
// The values of this cidrTree are `bool`, signifying allow/deny
|
||||||
cidrTree *cidr.Tree6
|
cidrTree *cidr.Tree6[bool]
|
||||||
}
|
}
|
||||||
|
|
||||||
type RemoteAllowList struct {
|
type RemoteAllowList struct {
|
||||||
|
@ -20,7 +20,7 @@ type RemoteAllowList struct {
|
||||||
|
|
||||||
// Inside Range Specific, keys of this tree are inside CIDRs and values
|
// Inside Range Specific, keys of this tree are inside CIDRs and values
|
||||||
// are *AllowList
|
// are *AllowList
|
||||||
insideAllowLists *cidr.Tree6
|
insideAllowLists *cidr.Tree6[*AllowList]
|
||||||
}
|
}
|
||||||
|
|
||||||
type LocalAllowList struct {
|
type LocalAllowList struct {
|
||||||
|
@ -88,7 +88,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
|
||||||
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
|
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
tree := cidr.NewTree6()
|
tree := cidr.NewTree6[bool]()
|
||||||
|
|
||||||
// Keep track of the rules we have added for both ipv4 and ipv6
|
// Keep track of the rules we have added for both ipv4 and ipv6
|
||||||
type allowListRules struct {
|
type allowListRules struct {
|
||||||
|
@ -218,13 +218,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
|
||||||
return nameRules, nil
|
return nameRules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
|
func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
|
||||||
value := c.Get(k)
|
value := c.Get(k)
|
||||||
if value == nil {
|
if value == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteAllowRanges := cidr.NewTree6()
|
remoteAllowRanges := cidr.NewTree6[*AllowList]()
|
||||||
|
|
||||||
rawMap, ok := value.(map[interface{}]interface{})
|
rawMap, ok := value.(map[interface{}]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -257,13 +257,8 @@ func (al *AllowList) Allow(ip net.IP) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
result := al.cidrTree.MostSpecificContains(ip)
|
_, result := al.cidrTree.MostSpecificContains(ip)
|
||||||
switch v := result.(type) {
|
return result
|
||||||
case bool:
|
|
||||||
return v
|
|
||||||
default:
|
|
||||||
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
|
func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
|
||||||
|
@ -271,13 +266,8 @@ func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
result := al.cidrTree.MostSpecificContainsIpV4(ip)
|
_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
|
||||||
switch v := result.(type) {
|
return result
|
||||||
case bool:
|
|
||||||
return v
|
|
||||||
default:
|
|
||||||
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
|
func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
|
||||||
|
@ -285,13 +275,8 @@ func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
|
_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
|
||||||
switch v := result.(type) {
|
return result
|
||||||
case bool:
|
|
||||||
return v
|
|
||||||
default:
|
|
||||||
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *LocalAllowList) Allow(ip net.IP) bool {
|
func (al *LocalAllowList) Allow(ip net.IP) bool {
|
||||||
|
@ -352,9 +337,9 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
|
||||||
|
|
||||||
func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
|
func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
|
||||||
if al.insideAllowLists != nil {
|
if al.insideAllowLists != nil {
|
||||||
inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
|
ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
|
||||||
if inside != nil {
|
if ok {
|
||||||
return inside.(*AllowList)
|
return inside
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -100,7 +100,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
|
||||||
func TestAllowList_Allow(t *testing.T) {
|
func TestAllowList_Allow(t *testing.T) {
|
||||||
assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
|
assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
|
||||||
|
|
||||||
tree := cidr.NewTree6()
|
tree := cidr.NewTree6[bool]()
|
||||||
tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
|
tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
|
||||||
tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
|
tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
|
||||||
tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
|
tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
|
||||||
|
|
|
@ -51,13 +51,13 @@ func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
|
||||||
return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
|
return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4, error) {
|
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
|
||||||
value := c.Get(k)
|
value := c.Get(k)
|
||||||
if value == nil {
|
if value == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
calculatedRemotes := cidr.NewTree4()
|
calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()
|
||||||
|
|
||||||
rawMap, ok := value.(map[any]any)
|
rawMap, ok := value.(map[any]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
|
@ -6,35 +6,36 @@ import (
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Node struct {
|
type Node[T any] struct {
|
||||||
left *Node
|
left *Node[T]
|
||||||
right *Node
|
right *Node[T]
|
||||||
parent *Node
|
parent *Node[T]
|
||||||
value interface{}
|
hasValue bool
|
||||||
|
value T
|
||||||
}
|
}
|
||||||
|
|
||||||
type entry struct {
|
type entry[T any] struct {
|
||||||
CIDR *net.IPNet
|
CIDR *net.IPNet
|
||||||
Value *interface{}
|
Value T
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tree4 struct {
|
type Tree4[T any] struct {
|
||||||
root *Node
|
root *Node[T]
|
||||||
list []entry
|
list []entry[T]
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
startbit = iputil.VpnIp(0x80000000)
|
startbit = iputil.VpnIp(0x80000000)
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewTree4() *Tree4 {
|
func NewTree4[T any]() *Tree4[T] {
|
||||||
tree := new(Tree4)
|
tree := new(Tree4[T])
|
||||||
tree.root = &Node{}
|
tree.root = &Node[T]{}
|
||||||
tree.list = []entry{}
|
tree.list = []entry[T]{}
|
||||||
return tree
|
return tree
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
|
func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
node := tree.root
|
node := tree.root
|
||||||
next := tree.root
|
next := tree.root
|
||||||
|
@ -68,14 +69,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
|
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
|
||||||
node.value = val
|
node.value = val
|
||||||
|
node.hasValue = true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build up the rest of the tree we don't already have
|
// Build up the rest of the tree we don't already have
|
||||||
for bit&mask != 0 {
|
for bit&mask != 0 {
|
||||||
next = &Node{}
|
next = &Node[T]{}
|
||||||
next.parent = node
|
next.parent = node
|
||||||
|
|
||||||
if ip&bit != 0 {
|
if ip&bit != 0 {
|
||||||
|
@ -90,17 +92,18 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||||
|
|
||||||
// Final node marks our cidr, set the value
|
// Final node marks our cidr, set the value
|
||||||
node.value = val
|
node.value = val
|
||||||
tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
|
node.hasValue = true
|
||||||
|
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Contains finds the first match, which may be the least specific
|
// Contains finds the first match, which may be the least specific
|
||||||
func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
|
func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
node := tree.root
|
node := tree.root
|
||||||
|
|
||||||
for node != nil {
|
for node != nil {
|
||||||
if node.value != nil {
|
if node.hasValue {
|
||||||
return node.value
|
return true, node.value
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip&bit != 0 {
|
if ip&bit != 0 {
|
||||||
|
@ -113,17 +116,18 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return value
|
return false, value
|
||||||
}
|
}
|
||||||
|
|
||||||
// MostSpecificContains finds the most specific match
|
// MostSpecificContains finds the most specific match
|
||||||
func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
|
func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
node := tree.root
|
node := tree.root
|
||||||
|
|
||||||
for node != nil {
|
for node != nil {
|
||||||
if node.value != nil {
|
if node.hasValue {
|
||||||
value = node.value
|
value = node.value
|
||||||
|
ok = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip&bit != 0 {
|
if ip&bit != 0 {
|
||||||
|
@ -135,11 +139,12 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
|
||||||
bit >>= 1
|
bit >>= 1
|
||||||
}
|
}
|
||||||
|
|
||||||
return value
|
return ok, value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Match finds the most specific match
|
// Match finds the most specific match
|
||||||
func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
|
// TODO this is exact match
|
||||||
|
func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
node := tree.root
|
node := tree.root
|
||||||
lastNode := node
|
lastNode := node
|
||||||
|
@ -157,11 +162,12 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
|
||||||
|
|
||||||
if bit == 0 && lastNode != nil {
|
if bit == 0 && lastNode != nil {
|
||||||
value = lastNode.value
|
value = lastNode.value
|
||||||
|
ok = true
|
||||||
}
|
}
|
||||||
return value
|
return ok, value
|
||||||
}
|
}
|
||||||
|
|
||||||
// List will return all CIDRs and their current values. Do not modify the contents!
|
// List will return all CIDRs and their current values. Do not modify the contents!
|
||||||
func (tree *Tree4) List() []entry {
|
func (tree *Tree4[T]) List() []entry[T] {
|
||||||
return tree.list
|
return tree.list
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCIDRTree_List(t *testing.T) {
|
func TestCIDRTree_List(t *testing.T) {
|
||||||
tree := NewTree4()
|
tree := NewTree4[string]()
|
||||||
tree.AddCIDR(Parse("1.0.0.0/16"), "1")
|
tree.AddCIDR(Parse("1.0.0.0/16"), "1")
|
||||||
tree.AddCIDR(Parse("1.0.0.0/8"), "2")
|
tree.AddCIDR(Parse("1.0.0.0/8"), "2")
|
||||||
tree.AddCIDR(Parse("1.0.0.0/16"), "3")
|
tree.AddCIDR(Parse("1.0.0.0/16"), "3")
|
||||||
|
@ -17,13 +17,13 @@ func TestCIDRTree_List(t *testing.T) {
|
||||||
list := tree.List()
|
list := tree.List()
|
||||||
assert.Len(t, list, 2)
|
assert.Len(t, list, 2)
|
||||||
assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
|
assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
|
||||||
assert.Equal(t, "2", *list[0].Value)
|
assert.Equal(t, "2", list[0].Value)
|
||||||
assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
|
assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
|
||||||
assert.Equal(t, "4", *list[1].Value)
|
assert.Equal(t, "4", list[1].Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCIDRTree_Contains(t *testing.T) {
|
func TestCIDRTree_Contains(t *testing.T) {
|
||||||
tree := NewTree4()
|
tree := NewTree4[string]()
|
||||||
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||||
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||||
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||||
|
@ -33,35 +33,43 @@ func TestCIDRTree_Contains(t *testing.T) {
|
||||||
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
Found bool
|
||||||
Result interface{}
|
Result interface{}
|
||||||
IP string
|
IP string
|
||||||
}{
|
}{
|
||||||
{"1", "1.0.0.0"},
|
{true, "1", "1.0.0.0"},
|
||||||
{"1", "1.255.255.255"},
|
{true, "1", "1.255.255.255"},
|
||||||
{"2", "2.1.0.0"},
|
{true, "2", "2.1.0.0"},
|
||||||
{"2", "2.1.255.255"},
|
{true, "2", "2.1.255.255"},
|
||||||
{"3", "3.1.1.0"},
|
{true, "3", "3.1.1.0"},
|
||||||
{"3", "3.1.1.255"},
|
{true, "3", "3.1.1.255"},
|
||||||
{"4a", "4.1.1.255"},
|
{true, "4a", "4.1.1.255"},
|
||||||
{"4a", "4.1.1.1"},
|
{true, "4a", "4.1.1.1"},
|
||||||
{"5", "240.0.0.0"},
|
{true, "5", "240.0.0.0"},
|
||||||
{"5", "255.255.255.255"},
|
{true, "5", "255.255.255.255"},
|
||||||
{nil, "239.0.0.0"},
|
{false, "", "239.0.0.0"},
|
||||||
{nil, "4.1.2.2"},
|
{false, "", "4.1.2.2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
|
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
|
||||||
|
assert.Equal(t, tt.Found, ok)
|
||||||
|
assert.Equal(t, tt.Result, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
tree = NewTree4()
|
tree = NewTree4[string]()
|
||||||
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||||
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
|
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
|
||||||
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "cool", r)
|
||||||
|
|
||||||
|
ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "cool", r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCIDRTree_MostSpecificContains(t *testing.T) {
|
func TestCIDRTree_MostSpecificContains(t *testing.T) {
|
||||||
tree := NewTree4()
|
tree := NewTree4[string]()
|
||||||
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||||
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||||
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||||
|
@ -71,59 +79,75 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) {
|
||||||
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
Found bool
|
||||||
Result interface{}
|
Result interface{}
|
||||||
IP string
|
IP string
|
||||||
}{
|
}{
|
||||||
{"1", "1.0.0.0"},
|
{true, "1", "1.0.0.0"},
|
||||||
{"1", "1.255.255.255"},
|
{true, "1", "1.255.255.255"},
|
||||||
{"2", "2.1.0.0"},
|
{true, "2", "2.1.0.0"},
|
||||||
{"2", "2.1.255.255"},
|
{true, "2", "2.1.255.255"},
|
||||||
{"3", "3.1.1.0"},
|
{true, "3", "3.1.1.0"},
|
||||||
{"3", "3.1.1.255"},
|
{true, "3", "3.1.1.255"},
|
||||||
{"4a", "4.1.1.255"},
|
{true, "4a", "4.1.1.255"},
|
||||||
{"4b", "4.1.1.2"},
|
{true, "4b", "4.1.1.2"},
|
||||||
{"4c", "4.1.1.1"},
|
{true, "4c", "4.1.1.1"},
|
||||||
{"5", "240.0.0.0"},
|
{true, "5", "240.0.0.0"},
|
||||||
{"5", "255.255.255.255"},
|
{true, "5", "255.255.255.255"},
|
||||||
{nil, "239.0.0.0"},
|
{false, "", "239.0.0.0"},
|
||||||
{nil, "4.1.2.2"},
|
{false, "", "4.1.2.2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
|
ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
|
||||||
|
assert.Equal(t, tt.Found, ok)
|
||||||
|
assert.Equal(t, tt.Result, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
tree = NewTree4()
|
tree = NewTree4[string]()
|
||||||
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||||
assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
|
ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
|
||||||
assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "cool", r)
|
||||||
|
|
||||||
|
ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "cool", r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCIDRTree_Match(t *testing.T) {
|
func TestCIDRTree_Match(t *testing.T) {
|
||||||
tree := NewTree4()
|
tree := NewTree4[string]()
|
||||||
tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
|
tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
|
||||||
tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
|
tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
Found bool
|
||||||
Result interface{}
|
Result interface{}
|
||||||
IP string
|
IP string
|
||||||
}{
|
}{
|
||||||
{"1a", "4.1.1.0"},
|
{true, "1a", "4.1.1.0"},
|
||||||
{"1b", "4.1.1.1"},
|
{true, "1b", "4.1.1.1"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
|
ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
|
||||||
|
assert.Equal(t, tt.Found, ok)
|
||||||
|
assert.Equal(t, tt.Result, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
tree = NewTree4()
|
tree = NewTree4[string]()
|
||||||
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||||
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
|
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
|
||||||
assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "cool", r)
|
||||||
|
|
||||||
|
ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "cool", r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkCIDRTree_Contains(b *testing.B) {
|
func BenchmarkCIDRTree_Contains(b *testing.B) {
|
||||||
tree := NewTree4()
|
tree := NewTree4[string]()
|
||||||
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
|
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
|
||||||
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
|
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
|
||||||
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
|
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
|
||||||
|
@ -145,7 +169,7 @@ func BenchmarkCIDRTree_Contains(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkCIDRTree_Match(b *testing.B) {
|
func BenchmarkCIDRTree_Match(b *testing.B) {
|
||||||
tree := NewTree4()
|
tree := NewTree4[string]()
|
||||||
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
|
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
|
||||||
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
|
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
|
||||||
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
|
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
|
||||||
|
|
|
@ -8,20 +8,20 @@ import (
|
||||||
|
|
||||||
const startbit6 = uint64(1 << 63)
|
const startbit6 = uint64(1 << 63)
|
||||||
|
|
||||||
type Tree6 struct {
|
type Tree6[T any] struct {
|
||||||
root4 *Node
|
root4 *Node[T]
|
||||||
root6 *Node
|
root6 *Node[T]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTree6() *Tree6 {
|
func NewTree6[T any]() *Tree6[T] {
|
||||||
tree := new(Tree6)
|
tree := new(Tree6[T])
|
||||||
tree.root4 = &Node{}
|
tree.root4 = &Node[T]{}
|
||||||
tree.root6 = &Node{}
|
tree.root6 = &Node[T]{}
|
||||||
return tree
|
return tree
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
|
func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
|
||||||
var node, next *Node
|
var node, next *Node[T]
|
||||||
|
|
||||||
cidrIP, ipv4 := isIPV4(cidr.IP)
|
cidrIP, ipv4 := isIPV4(cidr.IP)
|
||||||
if ipv4 {
|
if ipv4 {
|
||||||
|
@ -56,7 +56,7 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||||
|
|
||||||
// Build up the rest of the tree we don't already have
|
// Build up the rest of the tree we don't already have
|
||||||
for bit&mask != 0 {
|
for bit&mask != 0 {
|
||||||
next = &Node{}
|
next = &Node[T]{}
|
||||||
next.parent = node
|
next.parent = node
|
||||||
|
|
||||||
if ip&bit != 0 {
|
if ip&bit != 0 {
|
||||||
|
@ -72,11 +72,12 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
|
||||||
|
|
||||||
// Final node marks our cidr, set the value
|
// Final node marks our cidr, set the value
|
||||||
node.value = val
|
node.value = val
|
||||||
|
node.hasValue = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finds the most specific match
|
// Finds the most specific match
|
||||||
func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
|
func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
|
||||||
var node *Node
|
var node *Node[T]
|
||||||
|
|
||||||
wholeIP, ipv4 := isIPV4(ip)
|
wholeIP, ipv4 := isIPV4(ip)
|
||||||
if ipv4 {
|
if ipv4 {
|
||||||
|
@ -90,8 +91,9 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
|
|
||||||
for node != nil {
|
for node != nil {
|
||||||
if node.value != nil {
|
if node.hasValue {
|
||||||
value = node.value
|
value = node.value
|
||||||
|
ok = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if bit == 0 {
|
if bit == 0 {
|
||||||
|
@ -108,16 +110,17 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return value
|
return ok, value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) {
|
func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
|
||||||
bit := startbit
|
bit := startbit
|
||||||
node := tree.root4
|
node := tree.root4
|
||||||
|
|
||||||
for node != nil {
|
for node != nil {
|
||||||
if node.value != nil {
|
if node.hasValue {
|
||||||
value = node.value
|
value = node.value
|
||||||
|
ok = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip&bit != 0 {
|
if ip&bit != 0 {
|
||||||
|
@ -129,10 +132,10 @@ func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{})
|
||||||
bit >>= 1
|
bit >>= 1
|
||||||
}
|
}
|
||||||
|
|
||||||
return value
|
return ok, value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
|
func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
|
||||||
ip := hi
|
ip := hi
|
||||||
node := tree.root6
|
node := tree.root6
|
||||||
|
|
||||||
|
@ -140,8 +143,9 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
|
||||||
bit := startbit6
|
bit := startbit6
|
||||||
|
|
||||||
for node != nil {
|
for node != nil {
|
||||||
if node.value != nil {
|
if node.hasValue {
|
||||||
value = node.value
|
value = node.value
|
||||||
|
ok = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if bit == 0 {
|
if bit == 0 {
|
||||||
|
@ -160,7 +164,7 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
|
||||||
ip = lo
|
ip = lo
|
||||||
}
|
}
|
||||||
|
|
||||||
return value
|
return ok, value
|
||||||
}
|
}
|
||||||
|
|
||||||
func isIPV4(ip net.IP) (net.IP, bool) {
|
func isIPV4(ip net.IP) (net.IP, bool) {
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
|
func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
|
||||||
tree := NewTree6()
|
tree := NewTree6[string]()
|
||||||
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||||
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||||
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||||
|
@ -22,53 +22,68 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
|
||||||
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
|
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
Found bool
|
||||||
Result interface{}
|
Result interface{}
|
||||||
IP string
|
IP string
|
||||||
}{
|
}{
|
||||||
{"1", "1.0.0.0"},
|
{true, "1", "1.0.0.0"},
|
||||||
{"1", "1.255.255.255"},
|
{true, "1", "1.255.255.255"},
|
||||||
{"2", "2.1.0.0"},
|
{true, "2", "2.1.0.0"},
|
||||||
{"2", "2.1.255.255"},
|
{true, "2", "2.1.255.255"},
|
||||||
{"3", "3.1.1.0"},
|
{true, "3", "3.1.1.0"},
|
||||||
{"3", "3.1.1.255"},
|
{true, "3", "3.1.1.255"},
|
||||||
{"4a", "4.1.1.255"},
|
{true, "4a", "4.1.1.255"},
|
||||||
{"4b", "4.1.1.2"},
|
{true, "4b", "4.1.1.2"},
|
||||||
{"4c", "4.1.1.1"},
|
{true, "4c", "4.1.1.1"},
|
||||||
{"5", "240.0.0.0"},
|
{true, "5", "240.0.0.0"},
|
||||||
{"5", "255.255.255.255"},
|
{true, "5", "255.255.255.255"},
|
||||||
{"6a", "1:2:0:4:1:1:1:1"},
|
{true, "6a", "1:2:0:4:1:1:1:1"},
|
||||||
{"6b", "1:2:0:4:5:1:1:1"},
|
{true, "6b", "1:2:0:4:5:1:1:1"},
|
||||||
{"6c", "1:2:0:4:5:0:0:0"},
|
{true, "6c", "1:2:0:4:5:0:0:0"},
|
||||||
{nil, "239.0.0.0"},
|
{false, "", "239.0.0.0"},
|
||||||
{nil, "4.1.2.2"},
|
{false, "", "4.1.2.2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP)))
|
ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP))
|
||||||
|
assert.Equal(t, tt.Found, ok)
|
||||||
|
assert.Equal(t, tt.Result, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
tree = NewTree6()
|
tree = NewTree6[string]()
|
||||||
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||||
tree.AddCIDR(Parse("::/0"), "cool6")
|
tree.AddCIDR(Parse("::/0"), "cool6")
|
||||||
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0")))
|
ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0"))
|
||||||
assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255")))
|
assert.True(t, ok)
|
||||||
assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::")))
|
assert.Equal(t, "cool", r)
|
||||||
assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")))
|
|
||||||
|
ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255"))
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "cool", r)
|
||||||
|
|
||||||
|
ok, r = tree.MostSpecificContains(net.ParseIP("::"))
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "cool6", r)
|
||||||
|
|
||||||
|
ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "cool6", r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
|
func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
|
||||||
tree := NewTree6()
|
tree := NewTree6[string]()
|
||||||
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
|
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
|
||||||
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
|
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
|
||||||
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
|
tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
Found bool
|
||||||
Result interface{}
|
Result interface{}
|
||||||
IP string
|
IP string
|
||||||
}{
|
}{
|
||||||
{"6a", "1:2:0:4:1:1:1:1"},
|
{true, "6a", "1:2:0:4:1:1:1:1"},
|
||||||
{"6b", "1:2:0:4:5:1:1:1"},
|
{true, "6b", "1:2:0:4:5:1:1:1"},
|
||||||
{"6c", "1:2:0:4:5:0:0:0"},
|
{true, "6c", "1:2:0:4:5:0:0:0"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -76,6 +91,8 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
|
||||||
hi := binary.BigEndian.Uint64(ip[:8])
|
hi := binary.BigEndian.Uint64(ip[:8])
|
||||||
lo := binary.BigEndian.Uint64(ip[8:])
|
lo := binary.BigEndian.Uint64(ip[8:])
|
||||||
|
|
||||||
assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo))
|
ok, r := tree.MostSpecificContainsIpV6(hi, lo)
|
||||||
|
assert.Equal(t, tt.Found, ok)
|
||||||
|
assert.Equal(t, tt.Result, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
36
firewall.go
36
firewall.go
|
@ -57,7 +57,7 @@ type Firewall struct {
|
||||||
DefaultTimeout time.Duration //linux: 600s
|
DefaultTimeout time.Duration //linux: 600s
|
||||||
|
|
||||||
// Used to ensure we don't emit local packets for ips we don't own
|
// Used to ensure we don't emit local packets for ips we don't own
|
||||||
localIps *cidr.Tree4
|
localIps *cidr.Tree4[struct{}]
|
||||||
|
|
||||||
rules string
|
rules string
|
||||||
rulesVersion uint16
|
rulesVersion uint16
|
||||||
|
@ -110,8 +110,8 @@ type FirewallRule struct {
|
||||||
Any bool
|
Any bool
|
||||||
Hosts map[string]struct{}
|
Hosts map[string]struct{}
|
||||||
Groups [][]string
|
Groups [][]string
|
||||||
CIDR *cidr.Tree4
|
CIDR *cidr.Tree4[struct{}]
|
||||||
LocalCIDR *cidr.Tree4
|
LocalCIDR *cidr.Tree4[struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Even though ports are uint16, int32 maps are faster for lookup
|
// Even though ports are uint16, int32 maps are faster for lookup
|
||||||
|
@ -137,7 +137,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||||
max = defaultTimeout
|
max = defaultTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
localIps := cidr.NewTree4()
|
localIps := cidr.NewTree4[struct{}]()
|
||||||
for _, ip := range c.Details.Ips {
|
for _, ip := range c.Details.Ips {
|
||||||
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
||||||
}
|
}
|
||||||
|
@ -391,7 +391,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
|
||||||
|
|
||||||
// Make sure remote address matches nebula certificate
|
// Make sure remote address matches nebula certificate
|
||||||
if remoteCidr := h.remoteCidr; remoteCidr != nil {
|
if remoteCidr := h.remoteCidr; remoteCidr != nil {
|
||||||
if remoteCidr.Contains(fp.RemoteIP) == nil {
|
ok, _ := remoteCidr.Contains(fp.RemoteIP)
|
||||||
|
if !ok {
|
||||||
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
|
@ -404,7 +405,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we are supposed to be handling this local ip address
|
// Make sure we are supposed to be handling this local ip address
|
||||||
if f.localIps.Contains(fp.LocalIP) == nil {
|
ok, _ := f.localIps.Contains(fp.LocalIP)
|
||||||
|
if !ok {
|
||||||
f.metrics(incoming).droppedLocalIP.Inc(1)
|
f.metrics(incoming).droppedLocalIP.Inc(1)
|
||||||
return ErrInvalidLocalIP
|
return ErrInvalidLocalIP
|
||||||
}
|
}
|
||||||
|
@ -657,8 +659,8 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
|
||||||
return &FirewallRule{
|
return &FirewallRule{
|
||||||
Hosts: make(map[string]struct{}),
|
Hosts: make(map[string]struct{}),
|
||||||
Groups: make([][]string, 0),
|
Groups: make([][]string, 0),
|
||||||
CIDR: cidr.NewTree4(),
|
CIDR: cidr.NewTree4[struct{}](),
|
||||||
LocalCIDR: cidr.NewTree4(),
|
LocalCIDR: cidr.NewTree4[struct{}](),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -726,8 +728,8 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc
|
||||||
// If it's any we need to wipe out any pre-existing rules to save on memory
|
// If it's any we need to wipe out any pre-existing rules to save on memory
|
||||||
fr.Groups = make([][]string, 0)
|
fr.Groups = make([][]string, 0)
|
||||||
fr.Hosts = make(map[string]struct{})
|
fr.Hosts = make(map[string]struct{})
|
||||||
fr.CIDR = cidr.NewTree4()
|
fr.CIDR = cidr.NewTree4[struct{}]()
|
||||||
fr.LocalCIDR = cidr.NewTree4()
|
fr.LocalCIDR = cidr.NewTree4[struct{}]()
|
||||||
} else {
|
} else {
|
||||||
if len(groups) > 0 {
|
if len(groups) > 0 {
|
||||||
fr.Groups = append(fr.Groups, groups)
|
fr.Groups = append(fr.Groups, groups)
|
||||||
|
@ -809,12 +811,18 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil {
|
if fr.CIDR != nil {
|
||||||
return true
|
ok, _ := fr.CIDR.Contains(p.RemoteIP)
|
||||||
|
if ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
|
if fr.LocalCIDR != nil {
|
||||||
return true
|
ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
|
||||||
|
if ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// No host, group, or cidr matched, bye bye
|
// No host, group, or cidr matched, bye bye
|
||||||
|
|
|
@ -92,14 +92,16 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
|
ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
|
||||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
|
ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
|
||||||
|
@ -114,8 +116,10 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
|
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
|
||||||
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
||||||
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
|
ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
|
assert.True(t, ok)
|
||||||
|
ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
// run twice just to make sure
|
// run twice just to make sure
|
||||||
//TODO: these ANY rules should clear the CA firewall portion
|
//TODO: these ANY rules should clear the CA firewall portion
|
||||||
|
|
|
@ -205,7 +205,7 @@ type HostInfo struct {
|
||||||
localIndexId uint32
|
localIndexId uint32
|
||||||
vpnIp iputil.VpnIp
|
vpnIp iputil.VpnIp
|
||||||
recvError atomic.Uint32
|
recvError atomic.Uint32
|
||||||
remoteCidr *cidr.Tree4
|
remoteCidr *cidr.Tree4[struct{}]
|
||||||
relayState RelayState
|
relayState RelayState
|
||||||
|
|
||||||
// HandshakePacket records the packets used to create this hostinfo
|
// HandshakePacket records the packets used to create this hostinfo
|
||||||
|
@ -633,7 +633,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteCidr := cidr.NewTree4()
|
remoteCidr := cidr.NewTree4[struct{}]()
|
||||||
for _, ip := range c.Details.Ips {
|
for _, ip := range c.Details.Ips {
|
||||||
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,7 +74,7 @@ type LightHouse struct {
|
||||||
// IP's of relays that can be used by peers to access me
|
// IP's of relays that can be used by peers to access me
|
||||||
relaysForMe atomic.Pointer[[]iputil.VpnIp]
|
relaysForMe atomic.Pointer[[]iputil.VpnIp]
|
||||||
|
|
||||||
calculatedRemotes atomic.Pointer[cidr.Tree4] // Maps VpnIp to []*calculatedRemote
|
calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
|
||||||
|
|
||||||
metrics *MessageMetrics
|
metrics *MessageMetrics
|
||||||
metricHolepunchTx metrics.Counter
|
metricHolepunchTx metrics.Counter
|
||||||
|
@ -166,7 +166,7 @@ func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
|
||||||
return *lh.relaysForMe.Load()
|
return *lh.relaysForMe.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4 {
|
func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] {
|
||||||
return lh.calculatedRemotes.Load()
|
return lh.calculatedRemotes.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -594,11 +594,10 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
|
||||||
if tree == nil {
|
if tree == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
value := tree.MostSpecificContains(vpnIp)
|
ok, calculatedRemotes := tree.MostSpecificContains(vpnIp)
|
||||||
if value == nil {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
calculatedRemotes := value.([]*calculatedRemote)
|
|
||||||
|
|
||||||
var calculated []*Ip4AndPort
|
var calculated []*Ip4AndPort
|
||||||
for _, cr := range calculatedRemotes {
|
for _, cr := range calculatedRemotes {
|
||||||
|
|
|
@ -21,8 +21,8 @@ type Route struct {
|
||||||
Install bool
|
Install bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) {
|
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
|
||||||
routeTree := cidr.NewTree4()
|
routeTree := cidr.NewTree4[iputil.VpnIp]()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !allowMTU && r.MTU > 0 {
|
if !allowMTU && r.MTU > 0 {
|
||||||
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
||||||
|
|
|
@ -265,18 +265,16 @@ func Test_makeRouteTree(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
|
ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
|
||||||
r := routeTree.MostSpecificContains(ip)
|
ok, r := routeTree.MostSpecificContains(ip)
|
||||||
assert.NotNil(t, r)
|
assert.True(t, ok)
|
||||||
assert.IsType(t, iputil.VpnIp(0), r)
|
assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
|
||||||
assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
|
|
||||||
|
|
||||||
ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
|
ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
|
||||||
r = routeTree.MostSpecificContains(ip)
|
ok, r = routeTree.MostSpecificContains(ip)
|
||||||
assert.NotNil(t, r)
|
assert.True(t, ok)
|
||||||
assert.IsType(t, iputil.VpnIp(0), r)
|
assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
|
||||||
assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
|
|
||||||
|
|
||||||
ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
|
ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
|
||||||
r = routeTree.MostSpecificContains(ip)
|
ok, r = routeTree.MostSpecificContains(ip)
|
||||||
assert.Nil(t, r)
|
assert.False(t, ok)
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ type tun struct {
|
||||||
cidr *net.IPNet
|
cidr *net.IPNet
|
||||||
DefaultMTU int
|
DefaultMTU int
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *cidr.Tree4
|
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||||
|
@ -304,9 +304,9 @@ func (t *tun) Activate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||||
r := t.routeTree.MostSpecificContains(ip)
|
ok, r := t.routeTree.MostSpecificContains(ip)
|
||||||
if r != nil {
|
if ok {
|
||||||
return r.(iputil.VpnIp)
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
|
@ -48,7 +48,7 @@ type tun struct {
|
||||||
cidr *net.IPNet
|
cidr *net.IPNet
|
||||||
MTU int
|
MTU int
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *cidr.Tree4
|
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
|
@ -192,12 +192,8 @@ func (t *tun) Activate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||||
r := t.routeTree.MostSpecificContains(ip)
|
_, r := t.routeTree.MostSpecificContains(ip)
|
||||||
if r != nil {
|
return r
|
||||||
return r.(iputil.VpnIp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Cidr() *net.IPNet {
|
func (t *tun) Cidr() *net.IPNet {
|
||||||
|
|
|
@ -30,7 +30,7 @@ type tun struct {
|
||||||
TXQueueLen int
|
TXQueueLen int
|
||||||
|
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree atomic.Pointer[cidr.Tree4]
|
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||||
routeChan chan struct{}
|
routeChan chan struct{}
|
||||||
useSystemRoutes bool
|
useSystemRoutes bool
|
||||||
|
|
||||||
|
@ -154,12 +154,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||||
r := t.routeTree.Load().MostSpecificContains(ip)
|
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||||
if r != nil {
|
return r
|
||||||
return r.(iputil.VpnIp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Write(b []byte) (int, error) {
|
func (t *tun) Write(b []byte) (int, error) {
|
||||||
|
@ -380,7 +376,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newTree := cidr.NewTree4()
|
newTree := cidr.NewTree4[iputil.VpnIp]()
|
||||||
if r.Type == unix.RTM_NEWROUTE {
|
if r.Type == unix.RTM_NEWROUTE {
|
||||||
for _, oldR := range t.routeTree.Load().List() {
|
for _, oldR := range t.routeTree.Load().List() {
|
||||||
newTree.AddCIDR(oldR.CIDR, oldR.Value)
|
newTree.AddCIDR(oldR.CIDR, oldR.Value)
|
||||||
|
@ -392,7 +388,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||||
} else {
|
} else {
|
||||||
gw := iputil.Ip2VpnIp(r.Gw)
|
gw := iputil.Ip2VpnIp(r.Gw)
|
||||||
for _, oldR := range t.routeTree.Load().List() {
|
for _, oldR := range t.routeTree.Load().List() {
|
||||||
if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw {
|
if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
|
||||||
// This is the record to delete
|
// This is the record to delete
|
||||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
|
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -29,7 +29,7 @@ type tun struct {
|
||||||
cidr *net.IPNet
|
cidr *net.IPNet
|
||||||
MTU int
|
MTU int
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *cidr.Tree4
|
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
|
@ -134,12 +134,8 @@ func (t *tun) Activate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||||
r := t.routeTree.MostSpecificContains(ip)
|
_, r := t.routeTree.MostSpecificContains(ip)
|
||||||
if r != nil {
|
return r
|
||||||
return r.(iputil.VpnIp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Cidr() *net.IPNet {
|
func (t *tun) Cidr() *net.IPNet {
|
||||||
|
|
|
@ -23,7 +23,7 @@ type tun struct {
|
||||||
cidr *net.IPNet
|
cidr *net.IPNet
|
||||||
MTU int
|
MTU int
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *cidr.Tree4
|
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
|
@ -115,12 +115,8 @@ func (t *tun) Activate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||||
r := t.routeTree.MostSpecificContains(ip)
|
_, r := t.routeTree.MostSpecificContains(ip)
|
||||||
if r != nil {
|
return r
|
||||||
return r.(iputil.VpnIp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Cidr() *net.IPNet {
|
func (t *tun) Cidr() *net.IPNet {
|
||||||
|
|
|
@ -19,7 +19,7 @@ type TestTun struct {
|
||||||
Device string
|
Device string
|
||||||
cidr *net.IPNet
|
cidr *net.IPNet
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *cidr.Tree4
|
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
|
@ -83,12 +83,8 @@ func (t *TestTun) Get(block bool) []byte {
|
||||||
//********************************************************************************************************************//
|
//********************************************************************************************************************//
|
||||||
|
|
||||||
func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||||
r := t.routeTree.MostSpecificContains(ip)
|
_, r := t.routeTree.MostSpecificContains(ip)
|
||||||
if r != nil {
|
return r
|
||||||
return r.(iputil.VpnIp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) Activate() error {
|
func (t *TestTun) Activate() error {
|
||||||
|
|
|
@ -18,7 +18,7 @@ type waterTun struct {
|
||||||
cidr *net.IPNet
|
cidr *net.IPNet
|
||||||
MTU int
|
MTU int
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *cidr.Tree4
|
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||||
|
|
||||||
*water.Interface
|
*water.Interface
|
||||||
}
|
}
|
||||||
|
@ -97,12 +97,8 @@ func (t *waterTun) Activate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||||
r := t.routeTree.MostSpecificContains(ip)
|
_, r := t.routeTree.MostSpecificContains(ip)
|
||||||
if r != nil {
|
return r
|
||||||
return r.(iputil.VpnIp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *waterTun) Cidr() *net.IPNet {
|
func (t *waterTun) Cidr() *net.IPNet {
|
||||||
|
|
|
@ -24,7 +24,7 @@ type winTun struct {
|
||||||
prefix netip.Prefix
|
prefix netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *cidr.Tree4
|
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||||
|
|
||||||
tun *wintun.NativeTun
|
tun *wintun.NativeTun
|
||||||
}
|
}
|
||||||
|
@ -146,12 +146,8 @@ func (t *winTun) Activate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||||
r := t.routeTree.MostSpecificContains(ip)
|
_, r := t.routeTree.MostSpecificContains(ip)
|
||||||
if r != nil {
|
return r
|
||||||
return r.(iputil.VpnIp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTun) Cidr() *net.IPNet {
|
func (t *winTun) Cidr() *net.IPNet {
|
||||||
|
|
Loading…
Reference in New Issue