diff --git a/cidr/tree4.go b/cidr/tree4.go index 72c5130..c5ebe54 100644 --- a/cidr/tree4.go +++ b/cidr/tree4.go @@ -144,7 +144,7 @@ func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) { type eachFunc[T any] func(T) bool -// EachContains will call a function, passing the value, for each entry until the function returns false or the search is complete +// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete // The final return value will be true if the provided function returned true func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool { bit := startbit diff --git a/examples/config.yml b/examples/config.yml index c0969e1..d4ef0fd 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -316,7 +316,7 @@ firewall: # The firewall is default deny. There is no way to write a deny rule. # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR - # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) + # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr) # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). # code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any` # proto: `any`, `tcp`, `udp`, or `icmp` @@ -325,6 +325,7 @@ firewall: # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # cidr: a remote CIDR, `0.0.0.0/0` is any. # local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes. + # Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate. # ca_name: An issuing CA name # ca_sha: An issuing CA shasum @@ -346,3 +347,10 @@ firewall: groups: - laptop - home + + # Expose a subnet (unsafe route) to hosts with the group remote_client + # This example assume you have a subnet of 192.168.100.1/24 or larger encoded in the certificate + - port: 8080 + proto: tcp + group: remote_client + local_cidr: 192.168.100.1/24 diff --git a/firewall.go b/firewall.go index 642c107..c3cf7cf 100644 --- a/firewall.go +++ b/firewall.go @@ -58,7 +58,9 @@ type Firewall struct { DefaultTimeout time.Duration //linux: 600s // Used to ensure we don't emit local packets for ips we don't own - localIps *cidr.Tree4[struct{}] + localIps *cidr.Tree4[struct{}] + assignedCIDR *net.IPNet + hasSubnets bool rules string rulesVersion uint16 @@ -103,17 +105,22 @@ func newFirewallTable() *FirewallTable { } type FirewallCA struct { - Any *firewallLocalCIDR - CANames map[string]*firewallLocalCIDR - CAShas map[string]*firewallLocalCIDR + Any *FirewallRule + CANames map[string]*FirewallRule + CAShas map[string]*FirewallRule } type FirewallRule struct { // Any makes Hosts, Groups, and CIDR irrelevant - Any bool - Hosts map[string]struct{} - Groups [][]string - CIDR *cidr.Tree4[struct{}] + Any *firewallLocalCIDR + Hosts map[string]*firewallLocalCIDR + Groups []*firewallGroups + CIDR *cidr.Tree4[*firewallLocalCIDR] +} + +type firewallGroups struct { + Groups []string + LocalCIDR *firewallLocalCIDR } // Even though ports are uint16, int32 maps are faster for lookup @@ -121,8 +128,8 @@ type FirewallRule struct { type firewallPort map[int32]*FirewallCA type firewallLocalCIDR struct { - Any *FirewallRule - LocalCIDR *cidr.Tree4[*FirewallRule] + Any bool + LocalCIDR *cidr.Tree4[struct{}] } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. @@ -145,8 +152,15 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } localIps := cidr.NewTree4[struct{}]() + var assignedCIDR *net.IPNet for _, ip := range c.Details.Ips { - localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) + ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}} + localIps.AddCIDR(ipNet, struct{}{}) + + if assignedCIDR == nil { + // Only grabbing the first one in the cert since any more than that currently has undefined behavior + assignedCIDR = ipNet + } } for _, n := range c.Details.Subnets { @@ -164,6 +178,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D UDPTimeout: UDPTimeout, DefaultTimeout: defaultTimeout, localIps: localIps, + assignedCIDR: assignedCIDR, + hasSubnets: len(c.Details.Subnets) > 0, l: l, metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)), @@ -276,7 +292,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort return fmt.Errorf("unknown protocol %v", proto) } - return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha) + return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha) } // GetRuleHash returns a hash representation of all inbound and outbound rules @@ -630,7 +646,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC return false } -func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { if startPort > endPort { return fmt.Errorf("start port was lower than end port") } @@ -638,12 +654,12 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, for i := startPort; i <= endPort; i++ { if _, ok := fp[i]; !ok { fp[i] = &FirewallCA{ - CANames: make(map[string]*firewallLocalCIDR), - CAShas: make(map[string]*firewallLocalCIDR), + CANames: make(map[string]*FirewallRule), + CAShas: make(map[string]*FirewallRule), } } - if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil { + if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil { return err } } @@ -674,26 +690,28 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer return fp[firewall.PortAny].match(p, c, caPool) } -func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { - fl := func() *firewallLocalCIDR { - return &firewallLocalCIDR{ - LocalCIDR: cidr.NewTree4[*FirewallRule](), +func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { + fr := func() *FirewallRule { + return &FirewallRule{ + Hosts: make(map[string]*firewallLocalCIDR), + Groups: make([]*firewallGroups, 0), + CIDR: cidr.NewTree4[*firewallLocalCIDR](), } } if caSha == "" && caName == "" { if fc.Any == nil { - fc.Any = fl() + fc.Any = fr() } - return fc.Any.addRule(groups, host, ip, localIp) + return fc.Any.addRule(f, groups, host, ip, localIp) } if caSha != "" { if _, ok := fc.CAShas[caSha]; !ok { - fc.CAShas[caSha] = fl() + fc.CAShas[caSha] = fr() } - err := fc.CAShas[caSha].addRule(groups, host, ip, localIp) + err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp) if err != nil { return err } @@ -701,9 +719,9 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN if caName != "" { if _, ok := fc.CANames[caName]; !ok { - fc.CANames[caName] = fl() + fc.CANames[caName] = fr() } - err := fc.CANames[caName].addRule(groups, host, ip, localIp) + err := fc.CANames[caName].addRule(f, groups, host, ip, localIp) if err != nil { return err } @@ -735,75 +753,56 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool return fc.CANames[s.Details.Name].match(p, c) } -func (fc *firewallLocalCIDR) addRule(groups []string, host string, ip, localIp *net.IPNet) error { - fr := func() *FirewallRule { - return &FirewallRule{ - Hosts: make(map[string]struct{}), - Groups: make([][]string, 0), - CIDR: cidr.NewTree4[struct{}](), +func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error { + flc := func() *firewallLocalCIDR { + return &firewallLocalCIDR{ + LocalCIDR: cidr.NewTree4[struct{}](), } } - if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) { - if fc.Any == nil { - fc.Any = fr() - } - - return fc.Any.addRule(groups, host, ip) - } - - _, efr := fc.LocalCIDR.GetCIDR(localIp) - if efr != nil { - return efr.addRule(groups, host, ip) - } - - nfr := fr() - err := nfr.addRule(groups, host, ip) - if err != nil { - return err - } - - fc.LocalCIDR.AddCIDR(localIp, nfr) - return nil -} - -func (fc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool { - if fc == nil { - return false - } - - if fc.Any.match(p, c) { - return true - } - - return fc.LocalCIDR.EachContains(p.LocalIP, func(fr *FirewallRule) bool { - return fr.match(p, c) - }) -} - -func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error { - if fr.Any { - return nil - } - if fr.isAny(groups, host, ip) { - fr.Any = true - // If it's any we need to wipe out any pre-existing rules to save on memory - fr.Groups = make([][]string, 0) - fr.Hosts = make(map[string]struct{}) - fr.CIDR = cidr.NewTree4[struct{}]() - } else { - if len(groups) > 0 { - fr.Groups = append(fr.Groups, groups) + if fr.Any == nil { + fr.Any = flc() } - if host != "" { - fr.Hosts[host] = struct{}{} + return fr.Any.addRule(f, localCIDR) + } + + if len(groups) > 0 { + nlc := flc() + err := nlc.addRule(f, localCIDR) + if err != nil { + return err } - if ip != nil { - fr.CIDR.AddCIDR(ip, struct{}{}) + fr.Groups = append(fr.Groups, &firewallGroups{ + Groups: groups, + LocalCIDR: nlc, + }) + } + + if host != "" { + nlc := fr.Hosts[host] + if nlc == nil { + nlc = flc() } + err := nlc.addRule(f, localCIDR) + if err != nil { + return err + } + fr.Hosts[host] = nlc + } + + if ip != nil { + _, nlc := fr.CIDR.GetCIDR(ip) + if nlc == nil { + nlc = flc() + } + err := nlc.addRule(f, localCIDR) + if err != nil { + return err + } + fr.CIDR.AddCIDR(ip, nlc) } return nil @@ -837,7 +836,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } // Shortcut path for if groups, hosts, or cidr contained an `any` - if fr.Any { + if fr.Any.match(p, c) { return true } @@ -845,7 +844,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool for _, sg := range fr.Groups { found := false - for _, g := range sg { + for _, g := range sg.Groups { if _, ok := c.Details.InvertedGroups[g]; !ok { found = false break @@ -854,26 +853,48 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool found = true } - if found { + if found && sg.LocalCIDR.match(p, c) { return true } } if fr.Hosts != nil { - if _, ok := fr.Hosts[c.Details.Name]; ok { - return true + if flc, ok := fr.Hosts[c.Details.Name]; ok { + if flc.match(p, c) { + return true + } } } - if fr.CIDR != nil { - ok, _ := fr.CIDR.Contains(p.RemoteIP) - if ok { - return true + return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool { + return flc.match(p, c) + }) +} + +func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error { + if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) { + if !f.hasSubnets { + flc.Any = true + return nil } + localIp = f.assignedCIDR } - // No host, group, or cidr matched, bye bye - return false + flc.LocalCIDR.AddCIDR(localIp, struct{}{}) + return nil +} + +func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool { + if flc == nil { + return false + } + + if flc.Any { + return true + } + + ok, _ := flc.LocalCIDR.Contains(p.LocalIP) + return ok } type rule struct { diff --git a/firewall_test.go b/firewall_test.go index db31edf..7d65cb5 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -72,33 +72,32 @@ func TestFirewall_AddRule(t *testing.T) { assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) - assert.Empty(t, fw.InRules.TCP[1].Any.Any.Groups) - assert.Empty(t, fw.InRules.TCP[1].Any.Any.Hosts) + assert.Empty(t, fw.InRules.TCP[1].Any.Groups) + assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) - assert.False(t, fw.InRules.UDP[1].Any.Any.Any) - assert.Contains(t, fw.InRules.UDP[1].Any.Any.Groups[0], "g1") - assert.Empty(t, fw.InRules.UDP[1].Any.Any.Hosts) + assert.Nil(t, fw.InRules.UDP[1].Any.Any) + assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") + assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) - assert.False(t, fw.InRules.ICMP[1].Any.Any.Any) - assert.Empty(t, fw.InRules.ICMP[1].Any.Any.Groups) - assert.Contains(t, fw.InRules.ICMP[1].Any.Any.Hosts, "h1") + assert.Nil(t, fw.InRules.ICMP[1].Any.Any) + assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) + assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) - assert.False(t, fw.OutRules.AnyProto[1].Any.Any.Any) - ok, _ := fw.OutRules.AnyProto[1].Any.Any.CIDR.GetCIDR(ti) + assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) + ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) - assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) - ok, fr := fw.OutRules.AnyProto[1].Any.LocalCIDR.GetCIDR(ti) + assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) + ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti) assert.True(t, ok) - assert.True(t, fr.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) @@ -108,23 +107,6 @@ func TestFirewall_AddRule(t *testing.T) { assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") - // Set any and clear fields - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", "")) - ok, fr = fw.OutRules.AnyProto[0].Any.LocalCIDR.GetCIDR(ti) - assert.True(t, ok) - assert.False(t, fr.Any) - assert.Equal(t, []string{"g1", "g2"}, fr.Groups[0]) - assert.Contains(t, fr.Hosts, "h1") - - // run twice just to make sure - //TODO: these ANY rules should clear the CA firewall portion - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) - assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) - assert.Empty(t, fw.OutRules.AnyProto[0].Any.Any.Groups) - assert.Empty(t, fw.OutRules.AnyProto[0].Any.Any.Hosts) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) @@ -222,14 +204,15 @@ func TestFirewall_Drop(t *testing.T) { } func BenchmarkFirewallTable_match(b *testing.B) { + f := &Firewall{} ft := FirewallTable{ TCP: firewallPort{}, } _, n, _ := net.ParseCIDR("172.1.1.1/32") goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP) - _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, nil, "", "") - _ = ft.TCP.addRule(100, 100, []string{"good-group"}, "good-host", nil, n, "", "") + _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "") + _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) {