diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 9f239c2..263b240 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -18,8 +18,8 @@ import ( func TestGoodHandshake(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}) + myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Put their info in our lighthouse myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) @@ -70,9 +70,9 @@ func TestWrongResponderHandshake(t *testing.T) { // The IPs here are chosen on purpose: // The current remote handling will sort by preference, public, and then lexically. // So we need them to have a higher address than evil (we could apply a preference though) - myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}) - evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}) + myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) + theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) + evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) // Add their real udp addr, which should be tried after evil. myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) @@ -130,8 +130,8 @@ func TestWrongResponderHandshake(t *testing.T) { func Test_Case1_Stage1Race(t *testing.T) { ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}) - theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}) + myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) + theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) // Put their info in our lighthouse and vice versa myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) @@ -183,3 +183,113 @@ func Test_Case1_Stage1Race(t *testing.T) { } //TODO: add a test with many lies + +func TestPSK(t *testing.T) { + tests := []struct { + name string + myPskMode nebula.PskMode + theirPskMode nebula.PskMode + }{ + { + name: "none to transitional", + myPskMode: nebula.PskNone, + theirPskMode: nebula.PskTransitional, + }, + { + name: "transitional to none", + myPskMode: nebula.PskTransitional, + theirPskMode: nebula.PskNone, + }, + { + name: "both transitional", + myPskMode: nebula.PskTransitional, + theirPskMode: nebula.PskTransitional, + }, + + { + name: "enforced to transitional", + myPskMode: nebula.PskEnforced, + theirPskMode: nebula.PskTransitional, + }, + { + name: "transitional to enforced", + myPskMode: nebula.PskTransitional, + theirPskMode: nebula.PskEnforced, + }, + { + name: "both enforced", + myPskMode: nebula.PskEnforced, + theirPskMode: nebula.PskEnforced, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var myPskSettings, theirPskSettings *m + + switch test.myPskMode { + case nebula.PskNone: + myPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "none"}}} + case nebula.PskTransitional: + myPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "transitional", "keys": []string{"this is a key"}}}} + case nebula.PskEnforced: + myPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key"}}}} + } + + switch test.theirPskMode { + case nebula.PskNone: + theirPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "none"}}} + case nebula.PskTransitional: + theirPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "transitional", "keys": []string{"this is a key"}}}} + case nebula.PskEnforced: + theirPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key"}}}} + } + + ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, myPskSettings) + theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, theirPskSettings) + + myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr) + r := router.NewR(myControl, theirControl) + + // Start the servers + myControl.Start() + theirControl.Start() + + t.Log("Route until we see our cached packet flow") + myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me")) + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + h := &header.H{} + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + // If this is the stage 1 handshake packet and I am configured to enforce psk, my cert name should not appear. + // It would likely be more obvious to unmarshal the payload + if test.myPskMode == nebula.PskEnforced && h.Type == 0 && h.MessageCounter == 1 { + assert.NotContains(t, string(p.Data), "test me") + } + + if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 { + return router.RouteAndExit + } + + return router.KeepRouting + }) + + t.Log("My cached packet should be received by them") + myCachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80) + + t.Log("Test the tunnel with them") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) + assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() + //TODO: assert hostmaps + }) + } + +} diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 843e08c..414d752 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -15,6 +15,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/imdario/mergo" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" @@ -30,7 +31,7 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP) (*nebula.Control, net.IP, *net.UDPAddr) { +func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, customConfig *m) (*nebula.Control, net.IP, *net.UDPAddr) { l := NewTestLogger() vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} @@ -40,7 +41,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u IP: udpIp, Port: 4242, } - _, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) + _, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, "test "+name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { @@ -86,6 +87,24 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u c := config.NewC(l) c.LoadString(string(cb)) + if customConfig != nil { + ccb, err := yaml.Marshal(customConfig) + if err != nil { + panic(err) + } + + ccm := map[interface{}]interface{}{} + err = yaml.Unmarshal(ccb, &ccm) + if err != nil { + panic(err) + } + + err = mergo.Merge(&c.Settings, ccm, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + } + control, err := nebula.Main(c, false, "e2e-test", l, nil) if err != nil {