diff --git a/cert/cert.go b/cert/cert.go index fac72f9..e56b372 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -468,6 +468,63 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) { return json.Marshal(jc) } +//func (nc *NebulaCertificate) Copy() *NebulaCertificate { +// r, err := nc.Marshal() +// if err != nil { +// //TODO +// return nil +// } +// +// c, err := UnmarshalNebulaCertificate(r) +// return c +//} + +func (nc *NebulaCertificate) Copy() *NebulaCertificate { + c := &NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: nc.Details.Name, + Groups: make([]string, len(nc.Details.Groups)), + Ips: make([]*net.IPNet, len(nc.Details.Ips)), + Subnets: make([]*net.IPNet, len(nc.Details.Subnets)), + NotBefore: nc.Details.NotBefore, + NotAfter: nc.Details.NotAfter, + PublicKey: make([]byte, len(nc.Details.PublicKey)), + IsCA: nc.Details.IsCA, + Issuer: nc.Details.Issuer, + InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)), + }, + Signature: make([]byte, len(nc.Signature)), + } + + copy(c.Signature, nc.Signature) + copy(c.Details.Groups, nc.Details.Groups) + copy(c.Details.PublicKey, nc.Details.PublicKey) + + for i, p := range nc.Details.Ips { + c.Details.Ips[i] = &net.IPNet{ + IP: make(net.IP, len(p.IP)), + Mask: make(net.IPMask, len(p.Mask)), + } + copy(c.Details.Ips[i].IP, p.IP) + copy(c.Details.Ips[i].Mask, p.Mask) + } + + for i, p := range nc.Details.Subnets { + c.Details.Subnets[i] = &net.IPNet{ + IP: make(net.IP, len(p.IP)), + Mask: make(net.IPMask, len(p.Mask)), + } + copy(c.Details.Subnets[i].IP, p.IP) + copy(c.Details.Subnets[i].Mask, p.Mask) + } + + for g := range nc.Details.InvertedGroups { + c.Details.InvertedGroups[g] = struct{}{} + } + + return c +} + func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool { for _, net := range rootIps { if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) { diff --git a/cert/cert_test.go b/cert/cert_test.go index a647c0b..aff469c 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/golang/protobuf/proto" + "github.com/slackhq/nebula/util" "github.com/stretchr/testify/assert" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/ed25519" @@ -487,6 +488,17 @@ func TestMarshalingNebulaCertificateConsistency(t *testing.T) { assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) } +func TestNebulaCertificate_Copy(t *testing.T) { + ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + assert.Nil(t, err) + + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + assert.Nil(t, err) + cc := c.Copy() + + util.AssertDeepCopyEqual(t, c, cc) +} + func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { pub, priv, err := ed25519.GenerateKey(rand.Reader) if before.IsZero() { @@ -498,11 +510,12 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] nc := &NebulaCertificate{ Details: NebulaCertificateDetails{ - Name: "test ca", - NotBefore: before, - NotAfter: after, - PublicKey: pub, - IsCA: true, + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: true, + InvertedGroups: make(map[string]struct{}), }, } @@ -544,17 +557,17 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips if len(ips) == 0 { ips = []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + {IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())}, + {IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())}, + {IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, } } if len(subnets) == 0 { subnets = []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + {IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, + {IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())}, + {IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())}, } } @@ -562,15 +575,16 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips nc := &NebulaCertificate{ Details: NebulaCertificateDetails{ - Name: "testing", - Ips: ips, - Subnets: subnets, - Groups: groups, - NotBefore: before, - NotAfter: after, - PublicKey: pub, - IsCA: false, - Issuer: issuer, + Name: "testing", + Ips: ips, + Subnets: subnets, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + Issuer: issuer, + InvertedGroups: make(map[string]struct{}), }, } diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 8b9b9ea..912f470 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -55,7 +55,7 @@ func main() { l := logrus.New() l.Out = os.Stdout - err = nebula.Main(config, *configTest, true, Build, l, nil, nil) + c, err := nebula.Main(config, *configTest, Build, l, nil) switch v := err.(type) { case nebula.ContextualError: @@ -66,5 +66,10 @@ func main() { os.Exit(1) } + if !*configTest { + c.Start() + c.ShutdownBlock() + } + os.Exit(0) } diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index c04e7d8..6e1dcd9 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -14,21 +14,16 @@ import ( var logger service.Logger type program struct { - exit chan struct{} configPath *string configTest *bool build string + control *nebula.Control } func (p *program) Start(s service.Service) error { - logger.Info("Nebula service starting.") - p.exit = make(chan struct{}) // Start should not block. - go p.run() - return nil -} + logger.Info("Nebula service starting.") -func (p *program) run() error { config := nebula.NewConfig() err := config.Load(*p.configPath) if err != nil { @@ -37,17 +32,22 @@ func (p *program) run() error { l := logrus.New() l.Out = os.Stdout - return nebula.Main(config, *p.configTest, true, Build, l, nil, nil) + p.control, err = nebula.Main(config, *p.configTest, Build, l, nil) + if err != nil { + return err + } + + p.control.Start() + return nil } func (p *program) Stop(s service.Service) error { logger.Info("Nebula service stopping.") - close(p.exit) + p.control.Stop() return nil } func doService(configPath *string, configTest *bool, build string, serviceFlag *string) { - if *configPath == "" { ex, err := os.Executable() if err != nil { diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index 5697592..b28fa13 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -49,7 +49,7 @@ func main() { l := logrus.New() l.Out = os.Stdout - err = nebula.Main(config, *configTest, true, Build, l, nil, nil) + c, err := nebula.Main(config, *configTest, Build, l, nil) switch v := err.(type) { case nebula.ContextualError: @@ -60,5 +60,10 @@ func main() { os.Exit(1) } + if !*configTest { + c.Start() + c.ShutdownBlock() + } + os.Exit(0) } diff --git a/control.go b/control.go new file mode 100644 index 0000000..e16d07d --- /dev/null +++ b/control.go @@ -0,0 +1,169 @@ +package nebula + +import ( + "net" + "os" + "os/signal" + "syscall" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" +) + +// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching +// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc + +type Control struct { + f *Interface + l *logrus.Logger +} + +type ControlHostInfo struct { + VpnIP net.IP `json:"vpnIp"` + LocalIndex uint32 `json:"localIndex"` + RemoteIndex uint32 `json:"remoteIndex"` + RemoteAddrs []udpAddr `json:"remoteAddrs"` + CachedPackets int `json:"cachedPackets"` + Cert *cert.NebulaCertificate `json:"cert"` + MessageCounter uint64 `json:"messageCounter"` + CurrentRemote udpAddr `json:"currentRemote"` +} + +// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() +func (c *Control) Start() { + c.f.run() +} + +// Stop signals nebula to shutdown, returns after the shutdown is complete +func (c *Control) Stop() { + //TODO: stop tun and udp routines, the lock on hostMap effectively does that though + //TODO: this is probably better as a function in ConnectionManager or HostMap directly + c.f.hostMap.Lock() + for _, h := range c.f.hostMap.Hosts { + if h.ConnectionState.ready { + c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) + c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote). + Debug("Sending close tunnel message") + } + } + c.f.hostMap.Unlock() + c.l.Info("Goodbye") +} + +// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled +func (c *Control) ShutdownBlock() { + sigChan := make(chan os.Signal) + signal.Notify(sigChan, syscall.SIGTERM) + signal.Notify(sigChan, syscall.SIGINT) + + rawSig := <-sigChan + sig := rawSig.String() + c.l.WithField("signal", sig).Info("Caught signal, shutting down") + c.Stop() +} + +// RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change +func (c *Control) RebindUDPServer() { + _ = c.f.outside.Rebind() +} + +// ListHostmap returns details about the actual or pending (handshaking) hostmap +func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo { + var hm *HostMap + if pendingMap { + hm = c.f.handshakeManager.pendingHostMap + } else { + hm = c.f.hostMap + } + + hm.RLock() + hosts := make([]ControlHostInfo, len(hm.Hosts)) + i := 0 + for _, v := range hm.Hosts { + hosts[i] = copyHostInfo(v) + i++ + } + hm.RUnlock() + + return hosts +} + +// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found +func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo { + var hm *HostMap + if pending { + hm = c.f.handshakeManager.pendingHostMap + } else { + hm = c.f.hostMap + } + + h, err := hm.QueryVpnIP(vpnIP) + if err != nil { + return nil + } + + ch := copyHostInfo(h) + return &ch +} + +// SetRemoteForTunnel forces a tunnel to use a specific remote +func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo { + hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP) + if err != nil { + return nil + } + + hostInfo.SetRemote(addr.Copy()) + ch := copyHostInfo(hostInfo) + return &ch +} + +// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. +func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool { + hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP) + if err != nil { + return false + } + + if !localOnly { + c.f.send( + closeTunnel, + 0, + hostInfo.ConnectionState, + hostInfo, + hostInfo.remote, + []byte{}, + make([]byte, 12, 12), + make([]byte, mtu), + ) + } + + c.f.closeTunnel(hostInfo) + return true +} + +func copyHostInfo(h *HostInfo) ControlHostInfo { + addrs := h.RemoteUDPAddrs() + chi := ControlHostInfo{ + VpnIP: int2ip(h.hostId), + LocalIndex: h.localIndexId, + RemoteIndex: h.remoteIndexId, + RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)), + CachedPackets: len(h.packetStore), + MessageCounter: *h.ConnectionState.messageCounter, + } + + if c := h.GetCert(); c != nil { + chi.Cert = c.Copy() + } + + if h.remote != nil { + chi.CurrentRemote = *h.remote + } + + for i, addr := range addrs { + chi.RemoteAddrs[i] = addr.Copy() + } + + return chi +} diff --git a/control_test.go b/control_test.go new file mode 100644 index 0000000..f3ad7df --- /dev/null +++ b/control_test.go @@ -0,0 +1,111 @@ +package nebula + +import ( + "net" + "reflect" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/util" + "github.com/stretchr/testify/assert" +) + +func TestControl_GetHostInfoByVpnIP(t *testing.T) { + // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object + // To properly ensure we are not exposing core memory to the caller + hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0)) + remote1 := NewUDPAddr(100, 4444) + remote2 := NewUDPAddr(101, 4444) + ipNet := net.IPNet{ + IP: net.IPv4(1, 2, 3, 4), + Mask: net.IPMask{255, 255, 255, 0}, + } + + ipNet2 := net.IPNet{ + IP: net.IPv4(1, 2, 3, 5), + Mask: net.IPMask{255, 255, 255, 0}, + } + + crt := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "test", + Ips: []*net.IPNet{&ipNet}, + Subnets: []*net.IPNet{}, + Groups: []string{"default-group"}, + NotBefore: time.Unix(1, 0), + NotAfter: time.Unix(2, 0), + PublicKey: []byte{5, 6, 7, 8}, + IsCA: false, + Issuer: "the-issuer", + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + Signature: []byte{1, 2, 1, 2, 1, 3}, + } + counter := uint64(0) + + remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)} + hm.Add(ip2int(ipNet.IP), &HostInfo{ + remote: remote1, + Remotes: remotes, + ConnectionState: &ConnectionState{ + peerCert: crt, + messageCounter: &counter, + }, + remoteIndexId: 200, + localIndexId: 201, + hostId: ip2int(ipNet.IP), + }) + + hm.Add(ip2int(ipNet2.IP), &HostInfo{ + remote: remote1, + Remotes: remotes, + ConnectionState: &ConnectionState{ + peerCert: nil, + messageCounter: &counter, + }, + remoteIndexId: 200, + localIndexId: 201, + hostId: ip2int(ipNet2.IP), + }) + + c := Control{ + f: &Interface{ + hostMap: hm, + }, + l: logrus.New(), + } + + thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false) + + expectedInfo := ControlHostInfo{ + VpnIP: net.IPv4(1, 2, 3, 4).To4(), + LocalIndex: 201, + RemoteIndex: 200, + RemoteAddrs: []udpAddr{*remote1, *remote2}, + CachedPackets: 0, + Cert: crt.Copy(), + MessageCounter: 0, + CurrentRemote: *NewUDPAddr(100, 4444), + } + + // Make sure we don't have any unexpected fields + assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi) + util.AssertDeepCopyEqual(t, &expectedInfo, thi) + + // Make sure we don't panic if the host info doesn't have a cert yet + assert.NotPanics(t, func() { + thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false) + }) +} + +func assertFields(t *testing.T, expected []string, actualStruct interface{}) { + val := reflect.ValueOf(actualStruct).Elem() + fields := make([]string, val.NumField()) + for i := 0; i < val.NumField(); i++ { + fields[i] = val.Type().Field(i).Name + } + + assert.Equal(t, expected, fields) +} diff --git a/firewall.go b/firewall.go index 91638e1..42919fc 100644 --- a/firewall.go +++ b/firewall.go @@ -221,11 +221,17 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er // AddRule properly creates the in memory rule structure for a firewall table. func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { + // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS + // https://github.com/golang/go/issues/14131 + sIp := "" + if ip != nil { + sIp = ip.String() + } // We need this rule string because we generate a hash. Removing this will break firewall reload. ruleString := fmt.Sprintf( "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s", - incoming, proto, startPort, endPort, groups, host, ip, caName, caSha, + incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha, ) f.rules += ruleString + "\n" @@ -233,7 +239,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort if !incoming { direction = "outgoing" } - l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}). + l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}). Info("Firewall rule added") var ( diff --git a/go.mod b/go.mod index 324d3a5..32e36cc 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 github.com/sirupsen/logrus v1.4.2 github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b - github.com/stretchr/testify v1.4.0 + github.com/stretchr/testify v1.6.1 github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 diff --git a/go.sum b/go.sum index 0449d4f..ace45be 100644 --- a/go.sum +++ b/go.sum @@ -103,8 +103,8 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao= github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k= @@ -112,8 +112,6 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17 golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY= -golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g= -golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo= golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -154,3 +152,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/interface.go b/interface.go index 95caa12..ee90657 100644 --- a/interface.go +++ b/interface.go @@ -35,7 +35,10 @@ type InterfaceConfig struct { DropLocalBroadcast bool DropMulticast bool UDPBatchSize int + udpQueues int + tunQueues int MessageMetrics *MessageMetrics + version string } type Interface struct { @@ -54,6 +57,8 @@ type Interface struct { dropLocalBroadcast bool dropMulticast bool udpBatchSize int + udpQueues int + tunQueues int version string metricHandshakes metrics.Histogram @@ -89,6 +94,9 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) { dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast, udpBatchSize: c.UDPBatchSize, + udpQueues: c.udpQueues, + tunQueues: c.tunQueues, + version: c.version, metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), messageMetrics: c.MessageMetrics, @@ -99,29 +107,28 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) { return ifce, nil } -func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) { +func (f *Interface) run() { // actually turn on tun dev if err := f.inside.Activate(); err != nil { l.Fatal(err) } - f.version = buildVersion addr, err := f.outside.LocalAddr() if err != nil { l.WithError(err).Error("Failed to get udp listen address") } l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()). - WithField("build", buildVersion).WithField("udpAddr", addr). + WithField("build", f.version).WithField("udpAddr", addr). Info("Nebula interface is active") // Launch n queues to read packets from udp - for i := 0; i < udpRoutines; i++ { + for i := 0; i < f.udpQueues; i++ { go f.listenOut(i) } // Launch n queues to read packets from tun dev - for i := 0; i < tunRoutines; i++ { + for i := 0; i < f.tunQueues; i++ { go f.listenIn(i) } } diff --git a/logger.go b/logger.go index c8c7fdf..fa42f19 100644 --- a/logger.go +++ b/logger.go @@ -1,6 +1,8 @@ package nebula import ( + "errors" + "github.com/sirupsen/logrus" ) @@ -15,10 +17,16 @@ func NewContextualError(msg string, fields map[string]interface{}, realError err } func (ce ContextualError) Error() string { + if ce.RealError == nil { + return ce.Context + } return ce.RealError.Error() } func (ce ContextualError) Unwrap() error { + if ce.RealError == nil { + return errors.New(ce.Context) + } return ce.RealError } diff --git a/main.go b/main.go index 1c7cbb8..73ec8e7 100644 --- a/main.go +++ b/main.go @@ -4,11 +4,8 @@ import ( "encoding/binary" "fmt" "net" - "os" - "os/signal" "strconv" "strings" - "syscall" "time" "github.com/sirupsen/logrus" @@ -21,12 +18,7 @@ var l = logrus.New() type m map[string]interface{} -type CommandRequest struct { - Command string - Callback chan error -} - -func Main(config *Config, configTest bool, block bool, buildVersion string, logger *logrus.Logger, tunFd *int, commandChan <-chan CommandRequest) error { +func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) { l = logger l.Formatter = &logrus.TextFormatter{ FullTimestamp: true, @@ -36,7 +28,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg if configTest { b, err := yaml.Marshal(config.Settings) if err != nil { - return err + return nil, err } // Print the final config @@ -45,7 +37,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg err := configLogger(config) if err != nil { - return NewContextualError("Failed to configure the logger", nil, err) + return nil, NewContextualError("Failed to configure the logger", nil, err) } config.RegisterReloadCallback(func(c *Config) { @@ -59,20 +51,20 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg trustedCAs, err = loadCAFromConfig(config) if err != nil { //The errors coming out of loadCA are already nicely formatted - return NewContextualError("Failed to load ca from config", nil, err) + return nil, NewContextualError("Failed to load ca from config", nil, err) } l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints") cs, err := NewCertStateFromConfig(config) if err != nil { //The errors coming out of NewCertStateFromConfig are already nicely formatted - return NewContextualError("Failed to load certificate from config", nil, err) + return nil, NewContextualError("Failed to load certificate from config", nil, err) } l.WithField("cert", cs.certificate).Debug("Client nebula certificate") fw, err := NewFirewallFromConfig(cs.certificate, config) if err != nil { - return NewContextualError("Error while loading firewall rules", nil, err) + return nil, NewContextualError("Error while loading firewall rules", nil, err) } l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") @@ -80,11 +72,11 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg tunCidr := cs.certificate.Details.Ips[0] routes, err := parseRoutes(config, tunCidr) if err != nil { - return NewContextualError("Could not parse tun.routes", nil, err) + return nil, NewContextualError("Could not parse tun.routes", nil, err) } unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr) if err != nil { - return NewContextualError("Could not parse tun.unsafe_routes", nil, err) + return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err) } ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) @@ -92,7 +84,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg if config.GetBool("sshd.enabled", false) { err = configSSH(ssh, config) if err != nil { - return NewContextualError("Error while configuring the sshd", nil, err) + return nil, NewContextualError("Error while configuring the sshd", nil, err) } } @@ -129,7 +121,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg } if err != nil { - return NewContextualError("Failed to get a tun/tap device", nil, err) + return nil, NewContextualError("Failed to get a tun/tap device", nil, err) } } @@ -140,28 +132,11 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg if !configTest { udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1) if err != nil { - return NewContextualError("Failed to open udp listener", nil, err) + return nil, NewContextualError("Failed to open udp listener", nil, err) } udpServer.reloadConfig(config) } - sigChan := make(chan os.Signal) - killChan := make(chan CommandRequest) - if commandChan != nil { - go func() { - cmd := CommandRequest{} - for { - cmd = <-commandChan - switch cmd.Command { - case "rebind": - udpServer.Rebind() - case "exit": - killChan <- cmd - } - } - }() - } - // Set up my internal host map var preferredRanges []*net.IPNet rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{}) @@ -170,7 +145,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg for _, rawPreferredRange := range rawPreferredRanges { _, preferredRange, err := net.ParseCIDR(rawPreferredRange) if err != nil { - return NewContextualError("Failed to parse preferred ranges", nil, err) + return nil, NewContextualError("Failed to parse preferred ranges", nil, err) } preferredRanges = append(preferredRanges, preferredRange) } @@ -183,7 +158,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg if rawLocalRange != "" { _, localRange, err := net.ParseCIDR(rawLocalRange) if err != nil { - return NewContextualError("Failed to parse local_range", nil, err) + return nil, NewContextualError("Failed to parse local_range", nil, err) } // Check if the entry for local_range was already specified in @@ -223,7 +198,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg if port == 0 && !configTest { uPort, err := udpServer.LocalAddr() if err != nil { - return NewContextualError("Failed to get listening port", nil, err) + return nil, NewContextualError("Failed to get listening port", nil, err) } port = int(uPort.Port) } @@ -240,10 +215,10 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg for i, host := range rawLighthouseHosts { ip := net.ParseIP(host) if ip == nil { - return NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) + return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) } if !tunCidr.Contains(ip) { - return NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) + return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) } lighthouseHosts[i] = ip2int(ip) } @@ -263,13 +238,13 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false) if err != nil { - return NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) + return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) } lightHouse.SetRemoteAllowList(remoteAllowList) localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true) if err != nil { - return NewContextualError("Invalid lighthouse.local_allow_list", nil, err) + return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err) } lightHouse.SetLocalAllowList(localAllowList) @@ -277,7 +252,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) { vpnIp := net.ParseIP(fmt.Sprintf("%v", k)) if !tunCidr.Contains(vpnIp) { - return NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) + return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) } vals, ok := v.([]interface{}) if ok { @@ -288,7 +263,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg ip := addr.IP port, err := strconv.Atoi(parts[1]) if err != nil { - return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) + return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) } lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) } @@ -301,7 +276,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg ip := addr.IP port, err := strconv.Atoi(parts[1]) if err != nil { - return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) + return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) } lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) } @@ -354,7 +329,10 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false), DropMulticast: config.GetBool("tun.drop_multicast", false), UDPBatchSize: config.GetInt("listen.batch", 64), + udpQueues: udpQueues, + tunQueues: config.GetInt("tun.routines", 1), MessageMetrics: messageMetrics, + version: buildVersion, } switch ifConfig.Cipher { @@ -363,14 +341,14 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg case "chachapoly": noiseEndianness = binary.LittleEndian default: - return fmt.Errorf("unknown cipher: %v", ifConfig.Cipher) + return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher) } var ifce *Interface if !configTest { ifce, err = NewInterface(ifConfig) if err != nil { - return fmt.Errorf("failed to initialize interface: %s", err) + return nil, fmt.Errorf("failed to initialize interface: %s", err) } ifce.RegisterConfigChangeCallbacks(config) @@ -381,18 +359,17 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg err = startStats(config, configTest) if err != nil { - return NewContextualError("Failed to start stats emitter", nil, err) + return nil, NewContextualError("Failed to start stats emitter", nil, err) } if configTest { - return nil + return nil, nil } //TODO: check if we _should_ be emitting stats go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10)) attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) - ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion) // Start DNS server last to allow using the nebula IP as lighthouse.dns.host if amLighthouse && serveDns { @@ -400,47 +377,5 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg go dnsMain(hostMap, config) } - if block { - // Just sit here and be friendly, main thread. - shutdownBlock(ifce, sigChan, killChan) - } else { - // Even though we aren't blocking we still want to shutdown gracefully - go shutdownBlock(ifce, sigChan, killChan) - } - return nil -} - -func shutdownBlock(ifce *Interface, sigChan chan os.Signal, killChan chan CommandRequest) { - var cmd CommandRequest - var sig string - - signal.Notify(sigChan, syscall.SIGTERM) - signal.Notify(sigChan, syscall.SIGINT) - - select { - case rawSig := <-sigChan: - sig = rawSig.String() - case cmd = <-killChan: - sig = "controlling app" - } - - l.WithField("signal", sig).Info("Caught signal, shutting down") - - //TODO: stop tun and udp routines, the lock on hostMap effectively does that though - //TODO: this is probably better as a function in ConnectionManager or HostMap directly - ifce.hostMap.Lock() - for _, h := range ifce.hostMap.Hosts { - if h.ConnectionState.ready { - ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) - l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote). - Debug("Sending close tunnel message") - } - } - ifce.hostMap.Unlock() - - l.WithField("signal", sig).Info("Goodbye") - select { - case cmd.Callback <- nil: - default: - } + return &Control{ifce, l}, nil } diff --git a/udp_android.go b/udp_android.go index 7b6fea5..ac5606c 100644 --- a/udp_android.go +++ b/udp_android.go @@ -31,6 +31,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *udpConn) Rebind() { - return +func (u *udpConn) Rebind() error { + return nil } diff --git a/udp_freebsd.go b/udp_freebsd.go index 88730be..5910a9d 100644 --- a/udp_freebsd.go +++ b/udp_freebsd.go @@ -33,6 +33,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *udpConn) Rebind() { - return +func (u *udpConn) Rebind() error { + return nil } diff --git a/udp_generic.go b/udp_generic.go index c59f8c1..94d8cdf 100644 --- a/udp_generic.go +++ b/udp_generic.go @@ -65,6 +65,17 @@ func (ua *udpAddr) Equals(t *udpAddr) bool { return ua.IP.Equal(t.IP) && ua.Port == t.Port } +func (ua *udpAddr) Copy() udpAddr { + nu := udpAddr{net.UDPAddr{ + Port: ua.Port, + Zone: ua.Zone, + IP: make(net.IP, len(ua.IP)), + }} + + copy(nu.IP, ua.IP) + return nu +} + func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error { _, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr) return err diff --git a/udp_linux.go b/udp_linux.go index 8166838..2cde08d 100644 --- a/udp_linux.go +++ b/udp_linux.go @@ -89,8 +89,12 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) { return &udpConn{sysFd: fd}, err } -func (u *udpConn) Rebind() { - return +func (u *udpConn) Rebind() error { + return nil +} + +func (ua *udpAddr) Copy() udpAddr { + return *ua } func (u *udpConn) SetRecvBuffer(n int) error { @@ -282,13 +286,6 @@ func (ua *udpAddr) Equals(t *udpAddr) bool { return ua.IP == t.IP && ua.Port == t.Port } -func (ua *udpAddr) Copy() *udpAddr { - return &udpAddr{ - Port: ua.Port, - IP: ua.IP, - } -} - func (ua *udpAddr) String() string { return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port) } diff --git a/udp_windows.go b/udp_windows.go index 463f79d..dcfe884 100644 --- a/udp_windows.go +++ b/udp_windows.go @@ -21,6 +21,6 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *udpConn) Rebind() { - return +func (u *udpConn) Rebind() error { + return nil } diff --git a/util/assert.go b/util/assert.go new file mode 100644 index 0000000..6f13d6b --- /dev/null +++ b/util/assert.go @@ -0,0 +1,130 @@ +package util + +import ( + "fmt" + "reflect" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory +// There is currently a special case for `time.loc` (as this code traverses into unexported fields) +func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) { + v1 := reflect.ValueOf(a) + v2 := reflect.ValueOf(b) + + if !assert.Equal(t, v1.Type(), v2.Type()) { + return + } + + traverseDeepCopy(t, v1, v2, v1.Type().String()) +} + +func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool { + switch v1.Kind() { + case reflect.Array: + for i := 0; i < v1.Len(); i++ { + if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) { + return false + } + } + return true + + case reflect.Slice: + if v1.IsNil() || v2.IsNil() { + return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2) + } + + if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) { + return false + } + + // A slice with cap 0 + if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) { + return false + } + + v1c := v1.Cap() + v2c := v2.Cap() + if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() { + return assert.Fail(t, "", "%s share some underlying memory", name) + } + + for i := 0; i < v1.Len(); i++ { + if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) { + return false + } + } + return true + + case reflect.Interface: + if v1.IsNil() || v2.IsNil() { + return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name) + } + return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name) + + case reflect.Ptr: + local := reflect.ValueOf(time.Local).Pointer() + if local == v1.Pointer() && local == v2.Pointer() { + return true + } + + if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) { + return false + } + + return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name) + + case reflect.Struct: + for i, n := 0, v1.NumField(); i < n; i++ { + if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) { + return false + } + } + return true + + case reflect.Map: + if v1.IsNil() || v2.IsNil() { + return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name) + } + + if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) { + return false + } + + if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) { + return false + } + + for _, k := range v1.MapKeys() { + val1 := v1.MapIndex(k) + val2 := v2.MapIndex(k) + if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) { + return false + } + + if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) { + return false + } + + if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) { + return false + } + } + + return true + + default: + if v1.CanInterface() && v2.CanInterface() { + return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name) + } + + e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface() + e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface() + + return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name) + } +}