package service import ( "bytes" "context" "errors" "net" "testing" "time" "dario.cat/mergo" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v2" ) type m map[string]interface{} func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service { vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} copy(vpnIpNet.IP, udpIp) _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { panic(err) } mc := m{ "pki": m{ "ca": string(caB), "cert": string(myPEM), "key": string(myPrivKey), }, //"tun": m{"disabled": true}, "firewall": m{ "outbound": []m{{ "proto": "any", "port": "any", "host": "any", }}, "inbound": []m{{ "proto": "any", "port": "any", "host": "any", }}, }, "timers": m{ "pending_deletion_interval": 2, "connection_alive_interval": 2, }, "handshakes": m{ "try_interval": "200ms", }, } if overrides != nil { err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) if err != nil { panic(err) } mc = overrides } cb, err := yaml.Marshal(mc) if err != nil { panic(err) } var c config.C if err := c.LoadString(string(cb)); err != nil { panic(err) } s, err := New(&c) if err != nil { panic(err) } return s } func TestService(t *testing.T) { ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{ "static_host_map": m{}, "lighthouse": m{ "am_lighthouse": true, }, "listen": m{ "host": "0.0.0.0", "port": 4243, }, }) b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{ "static_host_map": m{ "10.0.0.1": []string{"localhost:4243"}, }, "lighthouse": m{ "hosts": []string{"10.0.0.1"}, "interval": 1, }, }) ln, err := a.Listen("tcp", ":1234") if err != nil { t.Fatal(err) } var eg errgroup.Group eg.Go(func() error { conn, err := ln.Accept() if err != nil { return err } defer conn.Close() t.Log("accepted connection") if _, err := conn.Write([]byte("server msg")); err != nil { return err } t.Log("server: wrote message") data := make([]byte, 100) n, err := conn.Read(data) if err != nil { return err } data = data[:n] if !bytes.Equal(data, []byte("client msg")) { return errors.New("got invalid message from client") } t.Log("server: read message") return conn.Close() }) c, err := b.DialContext(context.Background(), "tcp", "10.0.0.1:1234") if err != nil { t.Fatal(err) } if _, err := c.Write([]byte("client msg")); err != nil { t.Fatal(err) } data := make([]byte, 100) n, err := c.Read(data) if err != nil { t.Fatal(err) } data = data[:n] if !bytes.Equal(data, []byte("server msg")) { t.Fatal("got invalid message from client") } if err := c.Close(); err != nil { t.Fatal(err) } if err := eg.Wait(); err != nil { t.Fatal(err) } }