nebula/service/service_test.go

166 lines
3.2 KiB
Go
Raw Normal View History

package service
import (
"bytes"
"context"
"errors"
"net"
"testing"
"time"
"dario.cat/mergo"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
)
// m is shorthand for a generic string-keyed mapping. Tests use nested m
// literals to build a nebula config tree, which newSimpleService then merges
// with caller overrides and marshals to YAML.
type m map[string]interface{}
// newSimpleService builds a Service for tests from a minimal nebula config.
//
// A host certificate named name is issued by caCrt/caKey for the VPN address
// udpIp with a /24 mask, valid for five minutes from now. The base config
// embeds that PKI material, allows all inbound and outbound traffic, and sets
// the timer and handshake retry intervals listed below.
//
// overrides, if non-nil, is merged with the base config. mergo.Merge without
// WithOverride only fills keys that are missing (or zero) in overrides, so
// caller-supplied values take precedence; WithAppendSlice appends slices.
// Note that overrides is modified in place.
//
// Any failure panics, because this helper has no *testing.T to report through.
func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service {
	// Normalize IPv4 addresses to their 4-byte form so the IP length matches
	// the 4-byte mask below; net.ParseIP, for example, returns 16-byte slices.
	ip := udpIp.To4()
	if ip == nil {
		ip = udpIp
	}
	vpnIpNet := &net.IPNet{IP: make(net.IP, len(ip)), Mask: net.IPMask{255, 255, 255, 0}}
	copy(vpnIpNet.IP, ip)

	// Use the caller-supplied name so each test host gets its own certificate name.
	_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})

	caB, err := caCrt.MarshalToPEM()
	if err != nil {
		panic(err)
	}

	mc := m{
		"pki": m{
			"ca":   string(caB),
			"cert": string(myPEM),
			"key":  string(myPrivKey),
		},
		//"tun": m{"disabled": true},
		"firewall": m{
			"outbound": []m{{
				"proto": "any",
				"port":  "any",
				"host":  "any",
			}},
			"inbound": []m{{
				"proto": "any",
				"port":  "any",
				"host":  "any",
			}},
		},
		"timers": m{
			"pending_deletion_interval": 2,
			"connection_alive_interval": 2,
		},
		"handshakes": m{
			"try_interval": "200ms",
		},
	}

	if overrides != nil {
		// dst is overrides, so the base config only fills in what the caller omitted.
		if err := mergo.Merge(&overrides, mc, mergo.WithAppendSlice); err != nil {
			panic(err)
		}
		mc = overrides
	}

	cb, err := yaml.Marshal(mc)
	if err != nil {
		panic(err)
	}

	var c config.C
	if err := c.LoadString(string(cb)); err != nil {
		panic(err)
	}

	s, err := New(&c)
	if err != nil {
		panic(err)
	}
	return s
}
// TestService brings up two nebula services sharing one CA: "a" acts as the
// lighthouse listening on UDP 4243, and "b" reaches it via the static host map.
// It then opens a TCP connection from b to a's VPN address 10.0.0.1:1234
// through the tunnel and checks that one message is exchanged in each direction.
func TestService(t *testing.T) {
	ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
	a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{
		"static_host_map": m{},
		"lighthouse": m{
			"am_lighthouse": true,
		},
		"listen": m{
			"host": "0.0.0.0",
			"port": 4243,
		},
	})
	b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{
		"static_host_map": m{
			"10.0.0.1": []string{"localhost:4243"},
		},
		"lighthouse": m{
			"hosts":    []string{"10.0.0.1"},
			"interval": 1,
		},
	})

	ln, err := a.Listen("tcp", ":1234")
	if err != nil {
		t.Fatal(err)
	}
	// Closing the listener on exit also unblocks a pending Accept below if the
	// client side fails before it connects.
	defer ln.Close()

	// Server side: accept one connection, send a greeting, expect the client's reply.
	var eg errgroup.Group
	eg.Go(func() error {
		conn, err := ln.Accept()
		if err != nil {
			return err
		}
		defer conn.Close()
		t.Log("accepted connection")

		if _, err := conn.Write([]byte("server msg")); err != nil {
			return err
		}
		t.Log("server: wrote message")

		data := make([]byte, 100)
		n, err := conn.Read(data)
		if err != nil {
			return err
		}
		data = data[:n]
		if !bytes.Equal(data, []byte("client msg")) {
			return errors.New("got invalid message from client")
		}
		t.Log("server: read message")
		return conn.Close()
	})

	// Bound the dial so a tunnel that never comes up fails the test instead of
	// hanging it until the global test timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	c, err := b.DialContext(ctx, "tcp", "10.0.0.1:1234")
	if err != nil {
		t.Fatal(err)
	}
	if _, err := c.Write([]byte("client msg")); err != nil {
		t.Fatal(err)
	}

	data := make([]byte, 100)
	n, err := c.Read(data)
	if err != nil {
		t.Fatal(err)
	}
	data = data[:n]
	if !bytes.Equal(data, []byte("server msg")) {
		t.Fatalf("got invalid message from server: %q", data)
	}

	if err := c.Close(); err != nil {
		t.Fatal(err)
	}
	if err := eg.Wait(); err != nil {
		t.Fatal(err)
	}
}