diff --git a/control_tester.go b/control_tester.go index 7852943..4fa0763 100644 --- a/control_tester.go +++ b/control_tester.go @@ -6,6 +6,8 @@ package nebula import ( "net" + "github.com/slackhq/nebula/cert" + "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" @@ -14,7 +16,7 @@ import ( "github.com/slackhq/nebula/udp" ) -// WaitForTypeByIndex will pipe all messages from this control device into the pipeTo control device +// WaitForType will pipe all messages from this control device into the pipeTo control device // returning after a message matching the criteria has been piped func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) { h := &header.H{} @@ -153,3 +155,11 @@ func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo) return true } + +func (c *Control) GetHostmap() *HostMap { + return c.f.hostMap +} + +func (c *Control) GetCert() *cert.NebulaCertificate { + return c.f.certState.certificate +} diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index b92d7e0..bfde43e 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -85,6 +85,7 @@ func TestGoodHandshake(t *testing.T) { defer r.RenderFlow() assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) + r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() //TODO: assert hostmaps @@ -150,6 +151,7 @@ func TestWrongResponderHandshake(t *testing.T) { //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //TODO: assert hostmaps for everyone + r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) t.Log("Success!") myControl.Stop() theirControl.Stop() @@ -205,6 +207,7 @@ func Test_Case1_Stage1Race(t *testing.T) { t.Log("Do a bidirectional tunnel test") assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) + r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() //TODO: assert hostmaps @@ -235,6 +238,7 @@ func TestRelays(t *testing.T) { p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIp, theirVpnIp, 80, 80) + r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it } diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go new file mode 100644 index 0000000..948281a --- /dev/null +++ b/e2e/router/hostmap.go @@ -0,0 +1,109 @@ +//go:build e2e_testing +// +build e2e_testing + +package router + +import ( + "fmt" + "strings" + + "github.com/slackhq/nebula" +) + +type edge struct { + from string + to string + dual bool +} + +func renderHostmaps(controls ...*nebula.Control) string { + var lines []*edge + r := "graph TB\n" + for _, c := range controls { + sr, se := renderHostmap(c) + r += sr + for _, e := range se { + add := true + + // Collapse duplicate edges into a bi-directionally connected edge + for _, ge := range lines { + if e.to == ge.from && e.from == ge.to { + add = false + ge.dual = true + break + } + } + + if add { + lines = append(lines, e) + } + } + } + + for _, line := range lines { + if line.dual { + r += fmt.Sprintf("\t%v <--> %v\n", line.from, line.to) + } else { + r += fmt.Sprintf("\t%v --> %v\n", line.from, line.to) + } + + } + + return r +} + +func renderHostmap(c *nebula.Control) (string, []*edge) { + var lines []string + var globalLines []*edge + + clusterName := strings.Trim(c.GetCert().Details.Name, " ") + clusterVpnIp := c.GetCert().Details.Ips[0].IP + r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) + + hm := c.GetHostmap() + + // Draw the vpn to index nodes + r += fmt.Sprintf("\t\tsubgraph %s.hosts[\"Hosts (vpn ip to index)\"]\n", clusterName) + for vpnIp, hi := range hm.Hosts { + r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, vpnIp, vpnIp) + lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, hi.GetLocalIndex())) + + rs := hi.GetRelayState() + for _, relayIp := range rs.CopyRelayIps() { + lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, relayIp)) + } + + for _, relayIp := range rs.CopyRelayForIdxs() { + lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, relayIp)) + } + } + r += "\t\tend\n" + + // Draw the relay hostinfos + if len(hm.Relays) > 0 { + r += fmt.Sprintf("\t\tsubgraph %s.relays[\"Relays (relay index to hostinfo)\"]\n", clusterName) + for relayIndex, hi := range hm.Relays { + r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, relayIndex, relayIndex) + lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, relayIndex, clusterName, hi.GetLocalIndex())) + } + r += "\t\tend\n" + } + + // Draw the local index to relay or remote index nodes + r += fmt.Sprintf("\t\tsubgraph indexes.%s[\"Indexes (index to hostinfo)\"]\n", clusterName) + for idx, hi := range hm.Indexes { + r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) + remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ") + globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) + _ = hi + } + r += "\t\tend\n" + + // Add the edges inside this host + for _, line := range lines { + r += fmt.Sprintf("\t\t%v\n", line) + } + + r += "\tend\n" + return r, globalLines +} diff --git a/e2e/router/router.go b/e2e/router/router.go index 7b916a0..aa56db8 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -40,7 +40,12 @@ type R struct { // A map of vpn ip to the nebula control it belongs to vpnControls map[iputil.VpnIp]*nebula.Control - flow []flowEntry + ignoreFlows []ignoreFlow + flow []flowEntry + + // A set of additional mermaid graphs to draw in the flow log markdown file + // Currently consisting only of hostmap renders + additionalGraphs []mermaidGraph // All interactions are locked to help serialize behavior sync.Mutex @@ -50,6 +55,24 @@ type R struct { t testing.TB } +type ignoreFlow struct { + tun NullBool + messageType header.MessageType + subType header.MessageSubType + //from + //to +} + +type mermaidGraph struct { + title string + content string +} + +type NullBool struct { + HasValue bool + IsTrue bool +} + type flowEntry struct { note string packet *packet @@ -98,6 +121,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { inNat: make(map[string]*nebula.Control), outNat: make(map[string]net.UDPAddr), flow: []flowEntry{}, + ignoreFlows: []ignoreFlow{}, fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), t: t, cancelRender: cancel, @@ -219,15 +243,55 @@ func (r *R) renderFlow() { } fmt.Fprintf(f, - " %s%s%s: %s(%s), counter: %v\n", + " %s%s%s: %s(%s), index %v, counter: %v\n", strings.Replace(p.from.GetUDPAddr(), ":", "#58;", 1), line, strings.Replace(p.to.GetUDPAddr(), ":", "#58;", 1), - h.TypeName(), h.SubTypeName(), h.MessageCounter, + h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } } fmt.Fprintln(f, "```") + + for _, g := range r.additionalGraphs { + fmt.Fprintf(f, "## %s\n", g.title) + fmt.Fprintln(f, "```mermaid") + fmt.Fprintln(f, g.content) + fmt.Fprintln(f, "```") + } +} + +// IgnoreFlow tells the router to stop recording future flows that matches the provided criteria. +// messageType and subType will target nebula underlay packets while tun will target nebula overlay packets +// NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered +func (r *R) IgnoreFlow(messageType header.MessageType, subType header.MessageSubType, tun NullBool) { + r.Lock() + defer r.Unlock() + r.ignoreFlows = append(r.ignoreFlows, ignoreFlow{ + tun, + messageType, + subType, + }) +} + +func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { + r.Lock() + defer r.Unlock() + + s := renderHostmaps(controls...) + if len(r.additionalGraphs) > 0 { + lastGraph := r.additionalGraphs[len(r.additionalGraphs)-1] + if lastGraph.content == s && lastGraph.title == title { + // Ignore this rendering if it matches the last rendering added + // This is useful if you want to track rendering changes + return + } + } + + r.additionalGraphs = append(r.additionalGraphs, mermaidGraph{ + title: title, + content: s, + }) } // InjectFlow can be used to record packet flow if the test is handling the routing on its own. @@ -268,6 +332,24 @@ func (r *R) unlockedInjectFlow(from, to *nebula.Control, p *udp.Packet, tun bool return nil } + if len(r.ignoreFlows) > 0 { + var h header.H + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + for _, i := range r.ignoreFlows { + if !tun { + if i.messageType == h.Type && i.subType == h.Subtype { + return nil + } + } else if i.tun.HasValue && i.tun.IsTrue { + return nil + } + } + } + fp := &packet{ from: from, to: to, diff --git a/hostmap_tester.go b/hostmap_tester.go new file mode 100644 index 0000000..1d4323f --- /dev/null +++ b/hostmap_tester.go @@ -0,0 +1,24 @@ +//go:build e2e_testing +// +build e2e_testing + +package nebula + +// This file contains functions used to export information to the e2e testing framework + +import "github.com/slackhq/nebula/iputil" + +func (i *HostInfo) GetVpnIp() iputil.VpnIp { + return i.vpnIp +} + +func (i *HostInfo) GetLocalIndex() uint32 { + return i.localIndexId +} + +func (i *HostInfo) GetRemoteIndex() uint32 { + return i.remoteIndexId +} + +func (i *HostInfo) GetRelayState() RelayState { + return i.relayState +}