// Mirror of https://github.com/slackhq/nebula.git
// 190 lines, 2.8 KiB, Go
package cidr
|
|
|
|
import (
|
|
"net"
|
|
|
|
"github.com/slackhq/nebula/iputil"
|
|
)
|
|
|
|
// startbit6 is a uint64 with only its most significant bit set. It is the
// first bit mask tested when walking a 64-bit half of an IPv6 address.
const startbit6 uint64 = 1 << 63
|
|
|
|
type Tree6[T any] struct {
|
|
root4 *Node[T]
|
|
root6 *Node[T]
|
|
}
|
|
|
|
func NewTree6[T any]() *Tree6[T] {
|
|
tree := new(Tree6[T])
|
|
tree.root4 = &Node[T]{}
|
|
tree.root6 = &Node[T]{}
|
|
return tree
|
|
}
|
|
|
|
func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
|
|
var node, next *Node[T]
|
|
|
|
cidrIP, ipv4 := isIPV4(cidr.IP)
|
|
if ipv4 {
|
|
node = tree.root4
|
|
next = tree.root4
|
|
|
|
} else {
|
|
node = tree.root6
|
|
next = tree.root6
|
|
}
|
|
|
|
for i := 0; i < len(cidrIP); i += 4 {
|
|
ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
|
|
mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
|
|
bit := startbit
|
|
|
|
// Find our last ancestor in the tree
|
|
for bit&mask != 0 {
|
|
if ip&bit != 0 {
|
|
next = node.right
|
|
} else {
|
|
next = node.left
|
|
}
|
|
|
|
if next == nil {
|
|
break
|
|
}
|
|
|
|
bit = bit >> 1
|
|
node = next
|
|
}
|
|
|
|
// Build up the rest of the tree we don't already have
|
|
for bit&mask != 0 {
|
|
next = &Node[T]{}
|
|
next.parent = node
|
|
|
|
if ip&bit != 0 {
|
|
node.right = next
|
|
} else {
|
|
node.left = next
|
|
}
|
|
|
|
bit >>= 1
|
|
node = next
|
|
}
|
|
}
|
|
|
|
// Final node marks our cidr, set the value
|
|
node.value = val
|
|
node.hasValue = true
|
|
}
|
|
|
|
// Finds the most specific match
|
|
func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
|
|
var node *Node[T]
|
|
|
|
wholeIP, ipv4 := isIPV4(ip)
|
|
if ipv4 {
|
|
node = tree.root4
|
|
} else {
|
|
node = tree.root6
|
|
}
|
|
|
|
for i := 0; i < len(wholeIP); i += 4 {
|
|
ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
|
|
bit := startbit
|
|
|
|
for node != nil {
|
|
if node.hasValue {
|
|
value = node.value
|
|
ok = true
|
|
}
|
|
|
|
if bit == 0 {
|
|
break
|
|
}
|
|
|
|
if ip&bit != 0 {
|
|
node = node.right
|
|
} else {
|
|
node = node.left
|
|
}
|
|
|
|
bit >>= 1
|
|
}
|
|
}
|
|
|
|
return ok, value
|
|
}
|
|
|
|
func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
|
|
bit := startbit
|
|
node := tree.root4
|
|
|
|
for node != nil {
|
|
if node.hasValue {
|
|
value = node.value
|
|
ok = true
|
|
}
|
|
|
|
if ip&bit != 0 {
|
|
node = node.right
|
|
} else {
|
|
node = node.left
|
|
}
|
|
|
|
bit >>= 1
|
|
}
|
|
|
|
return ok, value
|
|
}
|
|
|
|
func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
|
|
ip := hi
|
|
node := tree.root6
|
|
|
|
for i := 0; i < 2; i++ {
|
|
bit := startbit6
|
|
|
|
for node != nil {
|
|
if node.hasValue {
|
|
value = node.value
|
|
ok = true
|
|
}
|
|
|
|
if bit == 0 {
|
|
break
|
|
}
|
|
|
|
if ip&bit != 0 {
|
|
node = node.right
|
|
} else {
|
|
node = node.left
|
|
}
|
|
|
|
bit >>= 1
|
|
}
|
|
|
|
ip = lo
|
|
}
|
|
|
|
return ok, value
|
|
}
|
|
|
|
func isIPV4(ip net.IP) (net.IP, bool) {
|
|
if len(ip) == net.IPv4len {
|
|
return ip, true
|
|
}
|
|
|
|
if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
|
|
return ip[12:16], true
|
|
}
|
|
|
|
return ip, false
|
|
}
|
|
|
|
// isZeros reports whether every byte of p is zero. An empty or nil p yields
// true.
func isZeros(p net.IP) bool {
	for _, b := range p {
		if b != 0 {
			return false
		}
	}
	return true
}
|