From 1083279a452ed1593805734df33a6beb865d4c6e Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 21 Nov 2023 08:50:18 -0800 Subject: [PATCH] add gvisor based service library (#965) * add service/ library --- control.go | 10 ++ e2e/handshakes_test.go | 36 +++--- e2e/helpers.go | 118 +++++++++++++++++ e2e/helpers_test.go | 111 +--------------- examples/go_service/main.go | 100 +++++++++++++++ go.mod | 6 +- go.sum | 16 ++- main.go | 10 +- overlay/tun.go | 40 ++++-- overlay/user.go | 63 +++++++++ service/listener.go | 36 ++++++ service/service.go | 248 ++++++++++++++++++++++++++++++++++++ service/service_test.go | 165 ++++++++++++++++++++++++ 13 files changed, 812 insertions(+), 147 deletions(-) create mode 100644 e2e/helpers.go create mode 100644 examples/go_service/main.go create mode 100644 overlay/user.go create mode 100644 service/listener.go create mode 100644 service/service.go create mode 100644 service/service_test.go diff --git a/control.go b/control.go index 13b2658..1e27b0f 100644 --- a/control.go +++ b/control.go @@ -11,6 +11,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -29,6 +30,7 @@ type controlHostLister interface { type Control struct { f *Interface l *logrus.Logger + ctx context.Context cancel context.CancelFunc sshStart func() statsStart func() @@ -71,6 +73,10 @@ func (c *Control) Start() { c.f.run() } +func (c *Control) Context() context.Context { + return c.ctx +} + // Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete func (c *Control) Stop() { // Stop the handshakeManager (and other services), to prevent new tunnels from @@ -226,6 +232,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { return } +func (c *Control) Device() overlay.Device { + return c.f.inside +} + func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { chi := ControlHostInfo{ diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 022b5a3..59f1d0e 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -20,7 +20,7 @@ import ( ) func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) @@ -44,7 +44,7 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) @@ -95,7 +95,7 @@ func TestGoodHandshake(t *testing.T) { } func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) // The IPs here are chosen on purpose: // The current remote handling will sort by preference, public, and then lexically. @@ -164,7 +164,7 @@ func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) @@ -241,7 +241,7 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) @@ -290,7 +290,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) @@ -341,7 +341,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) @@ -372,7 +372,7 @@ func TestRelays(t *testing.T) { func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) @@ -421,7 +421,7 @@ func TestStage1RaceRelays(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) @@ -508,7 +508,7 @@ func TestStage1RaceRelays2(t *testing.T) { ////TODO: assert hostmaps } func TestRehandshakingRelays(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) @@ -538,7 +538,7 @@ func TestRehandshakingRelays(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -612,7 +612,7 @@ func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) @@ -642,7 +642,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -715,7 +715,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) @@ -737,7 +737,7 @@ func TestRehandshaking(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") - _, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -811,7 +811,7 @@ func TestRehandshaking(t *testing.T) { func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) @@ -837,7 +837,7 @@ func TestRehandshakingLoser(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") - _, _, theirNextPrivKey, theirNextPEM := newTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) + _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -912,7 +912,7 @@ func TestRaceRegression(t *testing.T) { // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo - ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) diff --git a/e2e/helpers.go b/e2e/helpers.go new file mode 100644 index 0000000..13146ab --- /dev/null +++ b/e2e/helpers.go @@ -0,0 +1,118 @@ +package e2e + +import ( + "crypto/rand" + "io" + "net" + "time" + + "github.com/slackhq/nebula/cert" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" +) + +// NewTestCaCert will generate a CA cert +func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + nc := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: true, + InvertedGroups: make(map[string]struct{}), + }, + } + + if len(ips) > 0 { + nc.Details.Ips = ips + } + + if len(subnets) > 0 { + nc.Details.Subnets = subnets + } + + if len(groups) > 0 { + nc.Details.Groups = groups + } + + err = nc.Sign(cert.Curve_CURVE25519, priv) + if err != nil { + panic(err) + } + + pem, err := nc.MarshalToPEM() + if err != nil { + panic(err) + } + + return nc, pub, priv, pem +} + +// NewTestCert will generate a signed certificate with the provided details. +// Expiry times are defaulted if you do not pass them in +func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { + issuer, err := ca.Sha256Sum() + if err != nil { + panic(err) + } + + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + pub, rawPriv := x25519Keypair() + + nc := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: name, + Ips: []*net.IPNet{ip}, + Subnets: subnets, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + Issuer: issuer, + InvertedGroups: make(map[string]struct{}), + }, + } + + err = nc.Sign(ca.Details.Curve, key) + if err != nil { + panic(err) + } + + pem, err := nc.MarshalToPEM() + if err != nil { + panic(err) + } + + return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem +} + +func x25519Keypair() ([]byte, []byte) { + privkey := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, privkey); err != nil { + panic(err) + } + + pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) + if err != nil { + panic(err) + } + + return pubkey, privkey +} diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 8440a72..b05c84a 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -4,7 +4,6 @@ package e2e import ( - "crypto/rand" "fmt" "io" "net" @@ -22,8 +21,6 @@ import ( "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/ed25519" "gopkg.in/yaml.v2" ) @@ -40,7 +37,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u IP: udpIp, Port: 4242, } - _, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) + _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { @@ -108,112 +105,6 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u return control, vpnIpNet, &udpAddr, c } -// newTestCaCert will generate a CA cert -func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - InvertedGroups: make(map[string]struct{}), - }, - } - - if len(ips) > 0 { - nc.Details.Ips = ips - } - - if len(subnets) > 0 { - nc.Details.Subnets = subnets - } - - if len(groups) > 0 { - nc.Details.Groups = groups - } - - err = nc.Sign(cert.Curve_CURVE25519, priv) - if err != nil { - panic(err) - } - - pem, err := nc.MarshalToPEM() - if err != nil { - panic(err) - } - - return nc, pub, priv, pem -} - -// newTestCert will generate a signed certificate with the provided details. -// Expiry times are defaulted if you do not pass them in -func newTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - issuer, err := ca.Sha256Sum() - if err != nil { - panic(err) - } - - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - pub, rawPriv := x25519Keypair() - - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{ip}, - Subnets: subnets, - Groups: groups, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: false, - Issuer: issuer, - InvertedGroups: make(map[string]struct{}), - }, - } - - err = nc.Sign(ca.Details.Curve, key) - if err != nil { - panic(err) - } - - pem, err := nc.MarshalToPEM() - if err != nil { - panic(err) - } - - return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem -} - -func x25519Keypair() ([]byte, []byte) { - privkey := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, privkey); err != nil { - panic(err) - } - - pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) - if err != nil { - panic(err) - } - - return pubkey, privkey -} - type doneCb func() func deadline(t *testing.T, seconds time.Duration) doneCb { diff --git a/examples/go_service/main.go b/examples/go_service/main.go new file mode 100644 index 0000000..f46273a --- /dev/null +++ b/examples/go_service/main.go @@ -0,0 +1,100 @@ +package main + +import ( + "bufio" + "fmt" + "log" + + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/service" +) + +func main() { + if err := run(); err != nil { + log.Fatalf("%+v", err) + } +} + +func run() error { + configStr := ` +tun: + user: true + +static_host_map: + '192.168.100.1': ['localhost:4242'] + +listen: + host: 0.0.0.0 + port: 4241 + +lighthouse: + am_lighthouse: false + interval: 60 + hosts: + - '192.168.100.1' + +firewall: + outbound: + # Allow all outbound traffic from this node + - port: any + proto: any + host: any + + inbound: + # Allow icmp between any nebula hosts + - port: any + proto: icmp + host: any + - port: any + proto: any + host: any + +pki: + ca: /home/rice/Developer/nebula-config/ca.crt + cert: /home/rice/Developer/nebula-config/app.crt + key: /home/rice/Developer/nebula-config/app.key +` + var config config.C + if err := config.LoadString(configStr); err != nil { + return err + } + service, err := service.New(&config) + if err != nil { + return err + } + + ln, err := service.Listen("tcp", ":1234") + if err != nil { + return err + } + for { + conn, err := ln.Accept() + if err != nil { + log.Printf("accept error: %s", err) + break + } + defer conn.Close() + + log.Printf("got connection") + + conn.Write([]byte("hello world\n")) + + scanner := bufio.NewScanner(conn) + for scanner.Scan() { + message := scanner.Text() + fmt.Fprintf(conn, "echo: %q\n", message) + log.Printf("got message %q", message) + } + + if err := scanner.Err(); err != nil { + log.Printf("scanner error: %s", err) + break + } + } + + service.Close() + if err := service.Wait(); err != nil { + return err + } + return nil +} diff --git a/go.mod b/go.mod index 4bef77c..f9f237f 100644 --- a/go.mod +++ b/go.mod @@ -19,10 +19,11 @@ require ( github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.8.4 - github.com/vishvananda/netlink v1.1.0 + github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 golang.org/x/crypto v0.14.0 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 golang.org/x/net v0.17.0 + golang.org/x/sync v0.3.0 golang.org/x/sys v0.14.0 golang.org/x/term v0.13.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 @@ -30,6 +31,7 @@ require ( golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/protobuf v1.31.0 gopkg.in/yaml.v2 v2.4.0 + gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f ) require ( @@ -37,6 +39,7 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/google/btree v1.0.1 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect @@ -44,6 +47,7 @@ require ( github.com/prometheus/procfs v0.11.1 // indirect github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/mod v0.12.0 // indirect + golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect golang.org/x/tools v0.13.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 08f5378..5e13be3 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -135,9 +137,9 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= -github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= -github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= +github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 h1:8mhqcHPqTMhSPoslhGYihEgSfc77+7La1P6kiB6+9So= +github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -177,16 +179,18 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -201,6 +205,8 @@ golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= @@ -244,3 +250,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f h1:8GE2MRjGiFmfpon8dekPI08jEuNMQzSffVHgdupcO4E= +gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f/go.mod h1:pzr6sy8gDLfVmDAg8OYrlKvGEHw5C3PGTiBXBTCx76Q= diff --git a/main.go b/main.go index 14696ac..26f47eb 100644 --- a/main.go +++ b/main.go @@ -18,7 +18,7 @@ import ( type m map[string]interface{} -func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) { +func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. defer func() { @@ -128,7 +128,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if !configTest { c.CatchHUP(ctx) - tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines) + if deviceFactory == nil { + deviceFactory = overlay.NewDeviceFromConfig + } + + tun, err = deviceFactory(c, l, tunCidr, routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } @@ -159,6 +163,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } for i := 0; i < routines; i++ { + l.Infof("listening %q %d", listenHost.IP, port) udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) @@ -335,6 +340,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return &Control{ ifce, l, + ctx, cancel, sshStart, statsStart, diff --git a/overlay/tun.go b/overlay/tun.go index 5eccec9..ca1a64a 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -10,7 +10,9 @@ import ( const DefaultMTU = 1300 -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *int, routines int) (Device, error) { +type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) + +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { routes, err := parseRoutes(c, tunCidr) if err != nil { return nil, util.NewContextualError("Could not parse tun.routes", nil, err) @@ -27,17 +29,6 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd * tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil - case fd != nil: - return newTunFromFd( - l, - *fd, - tunCidr, - c.GetInt("tun.mtu", DefaultMTU), - routes, - c.GetInt("tun.tx_queue", 500), - c.GetBool("tun.use_system_route_table", false), - ) - default: return newTun( l, @@ -51,3 +42,28 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd * ) } } + +func NewFdDeviceFromConfig(fd *int) DeviceFactory { + return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { + routes, err := parseRoutes(c, tunCidr) + if err != nil { + return nil, util.NewContextualError("Could not parse tun.routes", nil, err) + } + + unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) + if err != nil { + return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) + } + routes = append(routes, unsafeRoutes...) + return newTunFromFd( + l, + *fd, + tunCidr, + c.GetInt("tun.mtu", DefaultMTU), + routes, + c.GetInt("tun.tx_queue", 500), + c.GetBool("tun.use_system_route_table", false), + ) + + } +} diff --git a/overlay/user.go b/overlay/user.go new file mode 100644 index 0000000..9d819ae --- /dev/null +++ b/overlay/user.go @@ -0,0 +1,63 @@ +package overlay + +import ( + "io" + "net" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/iputil" +) + +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { + return NewUserDevice(tunCidr) +} + +func NewUserDevice(tunCidr *net.IPNet) (Device, error) { + // these pipes guarantee each write/read will match 1:1 + or, ow := io.Pipe() + ir, iw := io.Pipe() + return &UserDevice{ + tunCidr: tunCidr, + outboundReader: or, + outboundWriter: ow, + inboundReader: ir, + inboundWriter: iw, + }, nil +} + +type UserDevice struct { + tunCidr *net.IPNet + + outboundReader *io.PipeReader + outboundWriter *io.PipeWriter + + inboundReader *io.PipeReader + inboundWriter *io.PipeWriter +} + +func (d *UserDevice) Activate() error { + return nil +} +func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip } +func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return d, nil +} + +func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { + return d.inboundReader, d.outboundWriter +} + +func (d *UserDevice) Read(p []byte) (n int, err error) { + return d.outboundReader.Read(p) +} +func (d *UserDevice) Write(p []byte) (n int, err error) { + return d.inboundWriter.Write(p) +} +func (d *UserDevice) Close() error { + d.inboundWriter.Close() + d.outboundWriter.Close() + return nil +} diff --git a/service/listener.go b/service/listener.go new file mode 100644 index 0000000..6d5c8a4 --- /dev/null +++ b/service/listener.go @@ -0,0 +1,36 @@ +package service + +import ( + "io" + "net" +) + +type tcpListener struct { + port uint16 + s *Service + addr *net.TCPAddr + accept chan net.Conn +} + +func (l *tcpListener) Accept() (net.Conn, error) { + conn, ok := <-l.accept + if !ok { + return nil, io.EOF + } + return conn, nil +} + +func (l *tcpListener) Close() error { + l.s.mu.Lock() + defer l.s.mu.Unlock() + delete(l.s.mu.listeners, uint16(l.addr.Port)) + + close(l.accept) + + return nil +} + +// Addr returns the listener's network address. +func (l *tcpListener) Addr() net.Addr { + return l.addr +} diff --git a/service/service.go b/service/service.go new file mode 100644 index 0000000..66ce864 --- /dev/null +++ b/service/service.go @@ -0,0 +1,248 @@ +package service + +import ( + "bytes" + "context" + "errors" + "fmt" + "log" + "math" + "net" + "os" + "strings" + "sync" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/overlay" + "golang.org/x/sync/errgroup" + "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const nicID = 1 + +type Service struct { + eg *errgroup.Group + control *nebula.Control + ipstack *stack.Stack + + mu struct { + sync.Mutex + + listeners map[uint16]*tcpListener + } +} + +func New(config *config.C) (*Service, error) { + logger := logrus.New() + logger.Out = os.Stdout + + control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) + if err != nil { + return nil, err + } + control.Start() + + ctx := control.Context() + eg, ctx := errgroup.WithContext(ctx) + s := Service{ + eg: eg, + control: control, + } + s.mu.listeners = map[uint16]*tcpListener{} + + device, ok := control.Device().(*overlay.UserDevice) + if !ok { + return nil, errors.New("must be using user device") + } + + s.ipstack = stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, + }) + sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default + tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) + if tcpipErr != nil { + return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) + } + linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "") + if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil { + return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem) + } + ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4))) + s.ipstack.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID, + }, + }) + + ipNet := device.Cidr() + pa := tcpip.ProtocolAddress{ + AddressWithPrefix: tcpip.Address(ipNet.IP).WithPrefix(), + Protocol: ipv4.ProtocolNumber, + } + if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ + PEB: stack.CanBePrimaryEndpoint, // zero value default + ConfigType: stack.AddressConfigStatic, // zero value default + }); err != nil { + return nil, fmt.Errorf("error creating IP: %s", err) + } + + const tcpReceiveBufferSize = 0 + const maxInFlightConnectionAttempts = 1024 + tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler) + s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) + + reader, writer := device.Pipe() + + go func() { + <-ctx.Done() + reader.Close() + writer.Close() + }() + + // create Goroutines to forward packets between Nebula and Gvisor + eg.Go(func() error { + buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize) + for { + // this will read exactly one packet + n, err := reader.Read(buf) + if err != nil { + return err + } + packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: bufferv2.MakeWithData(bytes.Clone(buf[:n])), + }) + linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) + + if err := ctx.Err(); err != nil { + return err + } + } + }) + eg.Go(func() error { + for { + packet := linkEP.ReadContext(ctx) + if packet.IsNil() { + if err := ctx.Err(); err != nil { + return err + } + continue + } + bufView := packet.ToView() + if _, err := bufView.WriteTo(writer); err != nil { + return err + } + bufView.Release() + } + }) + + return &s, nil +} + +// DialContext dials the provided address. Currently only TCP is supported. +func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if network != "tcp" && network != "tcp4" { + return nil, errors.New("only tcp is supported") + } + + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(addr.IP), + Port: uint16(addr.Port), + } + + return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber) +} + +// Listen listens on the provided address. Currently only TCP with wildcard +// addresses are supported. +func (s *Service) Listen(network, address string) (net.Listener, error) { + if network != "tcp" && network != "tcp4" { + return nil, errors.New("only tcp is supported") + } + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) { + return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP) + } + if addr.Port == 0 { + return nil, errors.New("specific port required, got 0") + } + if addr.Port < 0 || addr.Port >= math.MaxUint16 { + return nil, fmt.Errorf("invalid port %d", addr.Port) + } + port := uint16(addr.Port) + + l := &tcpListener{ + port: port, + s: s, + addr: addr, + accept: make(chan net.Conn), + } + + s.mu.Lock() + defer s.mu.Unlock() + + if _, ok := s.mu.listeners[port]; ok { + return nil, fmt.Errorf("already listening on port %d", port) + } + s.mu.listeners[port] = l + + return l, nil +} + +func (s *Service) Wait() error { + return s.eg.Wait() +} + +func (s *Service) Close() error { + s.control.Stop() + return nil +} + +func (s *Service) tcpHandler(r *tcp.ForwarderRequest) { + endpointID := r.ID() + + s.mu.Lock() + defer s.mu.Unlock() + + l, ok := s.mu.listeners[endpointID.LocalPort] + if !ok { + r.Complete(true) + return + } + + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + log.Printf("got error creating endpoint %q", err) + r.Complete(true) + return + } + r.Complete(false) + ep.SocketOptions().SetKeepAlive(true) + + conn := gonet.NewTCPConn(&wq, ep) + l.accept <- conn +} diff --git a/service/service_test.go b/service/service_test.go new file mode 100644 index 0000000..d1909cd --- /dev/null +++ b/service/service_test.go @@ -0,0 +1,165 @@ +package service + +import ( + "bytes" + "context" + "errors" + "net" + "testing" + "time" + + "dario.cat/mergo" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/e2e" + "golang.org/x/sync/errgroup" + "gopkg.in/yaml.v2" +) + +type m map[string]interface{} + +func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service { + + vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} + copy(vpnIpNet.IP, udpIp) + + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) + caB, err := caCrt.MarshalToPEM() + if err != nil { + panic(err) + } + + mc := m{ + "pki": m{ + "ca": string(caB), + "cert": string(myPEM), + "key": string(myPrivKey), + }, + //"tun": m{"disabled": true}, + "firewall": m{ + "outbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + }, + "timers": m{ + "pending_deletion_interval": 2, + "connection_alive_interval": 2, + }, + "handshakes": m{ + "try_interval": "200ms", + }, + } + + if overrides != nil { + err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = overrides + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + var c config.C + if err := c.LoadString(string(cb)); err != nil { + panic(err) + } + + s, err := New(&c) + if err != nil { + panic(err) + } + return s +} + +func TestService(t *testing.T) { + ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{ + "static_host_map": m{}, + "lighthouse": m{ + "am_lighthouse": true, + }, + "listen": m{ + "host": "0.0.0.0", + "port": 4243, + }, + }) + b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{ + "static_host_map": m{ + "10.0.0.1": []string{"localhost:4243"}, + }, + "lighthouse": m{ + "hosts": []string{"10.0.0.1"}, + "interval": 1, + }, + }) + + ln, err := a.Listen("tcp", ":1234") + if err != nil { + t.Fatal(err) + } + var eg errgroup.Group + eg.Go(func() error { + conn, err := ln.Accept() + if err != nil { + return err + } + defer conn.Close() + + t.Log("accepted connection") + + if _, err := conn.Write([]byte("server msg")); err != nil { + return err + } + + t.Log("server: wrote message") + + data := make([]byte, 100) + n, err := conn.Read(data) + if err != nil { + return err + } + data = data[:n] + if !bytes.Equal(data, []byte("client msg")) { + return errors.New("got invalid message from client") + } + t.Log("server: read message") + return conn.Close() + }) + + c, err := b.DialContext(context.Background(), "tcp", "10.0.0.1:1234") + if err != nil { + t.Fatal(err) + } + if _, err := c.Write([]byte("client msg")); err != nil { + t.Fatal(err) + } + + data := make([]byte, 100) + n, err := c.Read(data) + if err != nil { + t.Fatal(err) + } + data = data[:n] + if !bytes.Equal(data, []byte("server msg")) { + t.Fatal("got invalid message from client") + } + + if err := c.Close(); err != nil { + t.Fatal(err) + } + + if err := eg.Wait(); err != nil { + t.Fatal(err) + } +}