mirror of https://github.com/slackhq/nebula.git
204 lines
3.5 KiB
Go
204 lines
3.5 KiB
Go
package cidr
|
|
|
|
import (
|
|
"net"
|
|
|
|
"github.com/slackhq/nebula/iputil"
|
|
)
|
|
|
|
// Node is a single vertex of the binary prefix tree. Each level of the tree
// corresponds to one bit of an address: left is the child for a 0 bit and
// right is the child for a 1 bit.
type Node[T any] struct {
	left   *Node[T]
	right  *Node[T]
	parent *Node[T]
	// hasValue is true when a CIDR ends at this node; value is only
	// meaningful in that case.
	hasValue bool
	value    T
}
|
|
|
|
// entry pairs a CIDR that was added to the tree with the value stored for it.
type entry[T any] struct {
	CIDR  *net.IPNet
	Value T
}
|
|
|
|
type Tree4[T any] struct {
|
|
root *Node[T]
|
|
list []entry[T]
|
|
}
|
|
|
|
const (
|
|
startbit = iputil.VpnIp(0x80000000)
|
|
)
|
|
|
|
func NewTree4[T any]() *Tree4[T] {
|
|
tree := new(Tree4[T])
|
|
tree.root = &Node[T]{}
|
|
tree.list = []entry[T]{}
|
|
return tree
|
|
}
|
|
|
|
func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
|
|
bit := startbit
|
|
node := tree.root
|
|
next := tree.root
|
|
|
|
ip := iputil.Ip2VpnIp(cidr.IP)
|
|
mask := iputil.Ip2VpnIp(cidr.Mask)
|
|
|
|
// Find our last ancestor in the tree
|
|
for bit&mask != 0 {
|
|
if ip&bit != 0 {
|
|
next = node.right
|
|
} else {
|
|
next = node.left
|
|
}
|
|
|
|
if next == nil {
|
|
break
|
|
}
|
|
|
|
bit = bit >> 1
|
|
node = next
|
|
}
|
|
|
|
// We already have this range so update the value
|
|
if next != nil {
|
|
addCIDR := cidr.String()
|
|
for i, v := range tree.list {
|
|
if addCIDR == v.CIDR.String() {
|
|
tree.list = append(tree.list[:i], tree.list[i+1:]...)
|
|
break
|
|
}
|
|
}
|
|
|
|
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
|
|
node.value = val
|
|
node.hasValue = true
|
|
return
|
|
}
|
|
|
|
// Build up the rest of the tree we don't already have
|
|
for bit&mask != 0 {
|
|
next = &Node[T]{}
|
|
next.parent = node
|
|
|
|
if ip&bit != 0 {
|
|
node.right = next
|
|
} else {
|
|
node.left = next
|
|
}
|
|
|
|
bit >>= 1
|
|
node = next
|
|
}
|
|
|
|
// Final node marks our cidr, set the value
|
|
node.value = val
|
|
node.hasValue = true
|
|
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
|
|
}
|
|
|
|
// Contains finds the first match, which may be the least specific
|
|
func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
|
|
bit := startbit
|
|
node := tree.root
|
|
|
|
for node != nil {
|
|
if node.hasValue {
|
|
return true, node.value
|
|
}
|
|
|
|
if ip&bit != 0 {
|
|
node = node.right
|
|
} else {
|
|
node = node.left
|
|
}
|
|
|
|
bit >>= 1
|
|
|
|
}
|
|
|
|
return false, value
|
|
}
|
|
|
|
// MostSpecificContains finds the most specific match
|
|
func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
|
|
bit := startbit
|
|
node := tree.root
|
|
|
|
for node != nil {
|
|
if node.hasValue {
|
|
value = node.value
|
|
ok = true
|
|
}
|
|
|
|
if ip&bit != 0 {
|
|
node = node.right
|
|
} else {
|
|
node = node.left
|
|
}
|
|
|
|
bit >>= 1
|
|
}
|
|
|
|
return ok, value
|
|
}
|
|
|
|
// eachFunc is the callback type accepted by EachContains. It receives a stored
// value and returns true to stop the search.
type eachFunc[T any] func(value T) bool
|
|
|
|
// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
|
|
// The final return value will be true if the provided function returned true
|
|
func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
|
|
bit := startbit
|
|
node := tree.root
|
|
|
|
for node != nil {
|
|
if node.hasValue {
|
|
// If the each func returns true then we can exit the loop
|
|
if each(node.value) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
if ip&bit != 0 {
|
|
node = node.right
|
|
} else {
|
|
node = node.left
|
|
}
|
|
|
|
bit >>= 1
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// GetCIDR returns the entry added by the most recent matching AddCIDR call
|
|
func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
|
|
bit := startbit
|
|
node := tree.root
|
|
|
|
ip := iputil.Ip2VpnIp(cidr.IP)
|
|
mask := iputil.Ip2VpnIp(cidr.Mask)
|
|
|
|
// Find our last ancestor in the tree
|
|
for node != nil && bit&mask != 0 {
|
|
if ip&bit != 0 {
|
|
node = node.right
|
|
} else {
|
|
node = node.left
|
|
}
|
|
|
|
bit = bit >> 1
|
|
}
|
|
|
|
if bit&mask == 0 && node != nil {
|
|
value = node.value
|
|
ok = node.hasValue
|
|
}
|
|
|
|
return ok, value
|
|
}
|
|
|
|
// List will return all CIDRs and their current values. Do not modify the contents!
|
|
func (tree *Tree4[T]) List() []entry[T] {
|
|
return tree.list
|
|
}
|