From f22b4b584d64c945da3dfc22361a76b1f87bd8af Mon Sep 17 00:00:00 2001 From: Slack Security Team Date: Tue, 19 Nov 2019 17:00:20 +0000 Subject: [PATCH] Public Release --- .gitignore | 10 + AUTHORS | 9 + LICENSE | 24 + Makefile | 77 ++ README.md | 91 ++ bits.go | 157 ++++ bits_test.go | 223 +++++ cert.go | 159 ++++ cert/Makefile | 9 + cert/README.md | 15 + cert/ca.go | 120 +++ cert/cert.go | 445 ++++++++++ cert/cert.pb.go | 202 +++++ cert/cert.proto | 27 + cert/cert_test.go | 373 +++++++++ cidr_radix.go | 147 ++++ cidr_radix_test.go | 118 +++ cmd/nebula-cert/ca.go | 124 +++ cmd/nebula-cert/ca_test.go | 132 +++ cmd/nebula-cert/keygen.go | 65 ++ cmd/nebula-cert/keygen_test.go | 92 ++ cmd/nebula-cert/main.go | 137 +++ cmd/nebula-cert/main_test.go | 81 ++ cmd/nebula-cert/print.go | 80 ++ cmd/nebula-cert/print_test.go | 119 +++ cmd/nebula-cert/sign.go | 227 +++++ cmd/nebula-cert/sign_test.go | 281 +++++++ cmd/nebula-cert/test_unix.go | 4 + cmd/nebula-cert/test_windows.go | 4 + cmd/nebula-cert/verify.go | 86 ++ cmd/nebula-cert/verify_test.go | 141 ++++ cmd/nebula/main.go | 43 + config.go | 338 ++++++++ config_test.go | 141 ++++ connection_manager.go | 253 ++++++ connection_manager_test.go | 141 ++++ connection_state.go | 75 ++ dns_server.go | 125 +++ dns_server_test.go | 19 + examples/config.yaml | 160 ++++ examples/quickstart-vagrant/README.md | 154 ++++ examples/quickstart-vagrant/Vagrantfile | 40 + .../quickstart-vagrant/ansible/ansible.cfg | 4 + .../ansible/filter_plugins/to_nebula_ip.py | 21 + examples/quickstart-vagrant/ansible/inventory | 11 + .../quickstart-vagrant/ansible/playbook.yml | 20 + .../ansible/roles/nebula/defaults/main.yml | 3 + .../roles/nebula/files/systemd.nebula.service | 15 + .../roles/nebula/files/vagrant-test-ca.crt | 5 + .../roles/nebula/files/vagrant-test-ca.key | 4 + .../ansible/roles/nebula/handlers/main.yml | 5 + .../ansible/roles/nebula/tasks/main.yml | 56 ++ .../roles/nebula/templates/config.yml.j2 | 84 ++ .../ansible/roles/nebula/vars/main.yml | 7 + examples/quickstart-vagrant/requirements.yml | 1 + examples/service_scripts/nebula.init.d.sh | 51 ++ examples/service_scripts/nebula.service | 15 + firewall.go | 789 ++++++++++++++++++ firewall_test.go | 687 +++++++++++++++ go.mod | 32 + go.sum | 112 +++ handshake.go | 82 ++ handshake_ix.go | 356 ++++++++ handshake_manager.go | 200 +++++ handshake_manager_test.go | 191 +++++ header.go | 188 +++++ header_test.go | 118 +++ hostmap.go | 743 +++++++++++++++++ hostmap_test.go | 166 ++++ inside.go | 201 +++++ interface.go | 277 ++++++ lighthouse.go | 368 ++++++++ lighthouse_test.go | 76 ++ main.go | 321 +++++++ main_test.go | 1 + metadata.go | 18 + nebula.pb.go | 457 ++++++++++ nebula.proto | 59 ++ noise.go | 60 ++ outside.go | 410 +++++++++ outside_test.go | 80 ++ ssh.go | 727 ++++++++++++++++ sshd/command.go | 161 ++++ sshd/server.go | 182 ++++ sshd/session.go | 182 ++++ sshd/writer.go | 32 + stats.go | 89 ++ timeout.go | 191 +++++ timeout_system.go | 196 +++++ timeout_system_test.go | 134 +++ timeout_test.go | 138 +++ tun_common.go | 108 +++ tun_darwin.go | 59 ++ tun_linux.go | 249 ++++++ tun_test.go | 102 +++ tun_windows.go | 72 ++ udp_darwin.go | 34 + udp_generic.go | 123 +++ udp_linux.go | 315 +++++++ udp_linux_amd64.go | 50 ++ udp_linux_arm.go | 47 ++ udp_linux_arm64.go | 50 ++ udp_windows.go | 22 + 103 files changed, 14825 insertions(+) create mode 100644 .gitignore create mode 100644 AUTHORS create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 bits.go create mode 100644 bits_test.go create mode 100644 cert.go create mode 100644 cert/Makefile create mode 100644 cert/README.md create mode 100644 cert/ca.go create mode 100644 cert/cert.go create mode 100644 cert/cert.pb.go create mode 100644 cert/cert.proto create mode 100644 cert/cert_test.go create mode 100644 cidr_radix.go create mode 100644 cidr_radix_test.go create mode 100644 cmd/nebula-cert/ca.go create mode 100644 cmd/nebula-cert/ca_test.go create mode 100644 cmd/nebula-cert/keygen.go create mode 100644 cmd/nebula-cert/keygen_test.go create mode 100644 cmd/nebula-cert/main.go create mode 100644 cmd/nebula-cert/main_test.go create mode 100644 cmd/nebula-cert/print.go create mode 100644 cmd/nebula-cert/print_test.go create mode 100644 cmd/nebula-cert/sign.go create mode 100644 cmd/nebula-cert/sign_test.go create mode 100644 cmd/nebula-cert/test_unix.go create mode 100644 cmd/nebula-cert/test_windows.go create mode 100644 cmd/nebula-cert/verify.go create mode 100644 cmd/nebula-cert/verify_test.go create mode 100644 cmd/nebula/main.go create mode 100644 config.go create mode 100644 config_test.go create mode 100644 connection_manager.go create mode 100644 connection_manager_test.go create mode 100644 connection_state.go create mode 100644 dns_server.go create mode 100644 dns_server_test.go create mode 100644 examples/config.yaml create mode 100644 examples/quickstart-vagrant/README.md create mode 100644 examples/quickstart-vagrant/Vagrantfile create mode 100644 examples/quickstart-vagrant/ansible/ansible.cfg create mode 100644 examples/quickstart-vagrant/ansible/filter_plugins/to_nebula_ip.py create mode 100644 examples/quickstart-vagrant/ansible/inventory create mode 100644 examples/quickstart-vagrant/ansible/playbook.yml create mode 100644 examples/quickstart-vagrant/ansible/roles/nebula/defaults/main.yml create mode 100644 examples/quickstart-vagrant/ansible/roles/nebula/files/systemd.nebula.service create mode 100644 examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt create mode 100644 examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key create mode 100644 examples/quickstart-vagrant/ansible/roles/nebula/handlers/main.yml create mode 100644 examples/quickstart-vagrant/ansible/roles/nebula/tasks/main.yml create mode 100644 examples/quickstart-vagrant/ansible/roles/nebula/templates/config.yml.j2 create mode 100644 examples/quickstart-vagrant/ansible/roles/nebula/vars/main.yml create mode 100644 examples/quickstart-vagrant/requirements.yml create mode 100644 examples/service_scripts/nebula.init.d.sh create mode 100644 examples/service_scripts/nebula.service create mode 100644 firewall.go create mode 100644 firewall_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 handshake.go create mode 100644 handshake_ix.go create mode 100644 handshake_manager.go create mode 100644 handshake_manager_test.go create mode 100644 header.go create mode 100644 header_test.go create mode 100644 hostmap.go create mode 100644 hostmap_test.go create mode 100644 inside.go create mode 100644 interface.go create mode 100644 lighthouse.go create mode 100644 lighthouse_test.go create mode 100644 main.go create mode 100644 main_test.go create mode 100644 metadata.go create mode 100644 nebula.pb.go create mode 100644 nebula.proto create mode 100644 noise.go create mode 100644 outside.go create mode 100644 outside_test.go create mode 100644 ssh.go create mode 100644 sshd/command.go create mode 100644 sshd/server.go create mode 100644 sshd/session.go create mode 100644 sshd/writer.go create mode 100644 stats.go create mode 100644 timeout.go create mode 100644 timeout_system.go create mode 100644 timeout_system_test.go create mode 100644 timeout_test.go create mode 100644 tun_common.go create mode 100644 tun_darwin.go create mode 100644 tun_linux.go create mode 100644 tun_test.go create mode 100644 tun_windows.go create mode 100644 udp_darwin.go create mode 100644 udp_generic.go create mode 100644 udp_linux.go create mode 100644 udp_linux_amd64.go create mode 100644 udp_linux_arm.go create mode 100644 udp_linux_arm64.go create mode 100644 udp_windows.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6dd38c4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +/nebula +/nebula-cert +/nebula-arm +/nebula-arm6 +/nebula-darwin +/nebula.exe +/cert/*.crt +/cert/*.key +/coverage.out +/cpu.pprof diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 0000000..261ed36 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,9 @@ +# This is the official list of Nebula authors for copyright purposes. + +# Names should be added to this file as: +# Name or Organization +# The email address is not required for organizations. + +Slack Technologies, Inc. +Nate Brown +Ryan Huber diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..dbc2951 --- /dev/null +++ b/LICENSE @@ -0,0 +1,24 @@ +MIT License + +Copyright (c) 2018-2019 Slack Technologies, Inc. + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5e64ab8 --- /dev/null +++ b/Makefile @@ -0,0 +1,77 @@ +BUILD_NUMBER ?= dev+$(shell date -u '+%Y%m%d%H%M%S') +GO111MODULE = on +export GO111MODULE + +all: + make bin + make bin-arm + make bin-arm6 + make bin-arm64 + make bin-darwin + make bin-windows + +bin: + go build -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula ./cmd/nebula + go build -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula-cert ./cmd/nebula-cert + +install: + go install -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula + go install -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert + +bin-arm: + GOARCH=arm GOOS=linux go build -o nebula-arm -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula + +bin-arm6: + GOARCH=arm GOARM=6 GOOS=linux go build -o nebula-arm6 -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula + +bin-arm64: + GOARCH=arm64 GOOS=linux go build -o nebula-arm64 -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula + + +bin-vagrant: + GOARCH=amd64 GOOS=linux go build -o nebula -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula + GOARCH=amd64 GOOS=linux go build -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula-cert ./cmd/nebula-cert +bin-darwin: + GOARCH=amd64 GOOS=darwin go build -o nebula-darwin -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula + +bin-windows: + GOARCH=amd64 GOOS=windows go build -o nebula.exe -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula + +bin-linux: + GOARCH=amd64 GOOS=linux go build -o ./nebula -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula + GOARCH=amd64 GOOS=linux go build -o ./nebula-cert -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert + +vet: + go vet -v ./... + +test: + go test -v ./... + +test-cov-html: + go test -coverprofile=coverage.out + go tool cover -html=coverage.out + +bench: + go test -bench=. + +bench-cpu: + go test -bench=. -benchtime=5s -cpuprofile=cpu.pprof + go tool pprof go-audit.test cpu.pprof + +bench-cpu-long: + go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof + go tool pprof go-audit.test cpu.pprof + +proto: nebula.pb.go cert/cert.pb.go + +nebula.pb.go: nebula.proto .FORCE + go build github.com/golang/protobuf/protoc-gen-go + PATH="$(PWD):$(PATH)" protoc --go_out=. $< + rm protoc-gen-go + +cert/cert.pb.go: cert/cert.proto .FORCE + $(MAKE) -C cert cert.pb.go + +.FORCE: +.PHONY: test test-cov-html bench bench-cpu bench-cpu-long bin proto +.DEFAULT_GOAL := bin diff --git a/README.md b/README.md new file mode 100644 index 0000000..baa5313 --- /dev/null +++ b/README.md @@ -0,0 +1,91 @@ +## What is Nebula? +Nebula is a scalable overlay networking tool with a focus on performance, simplicity and security. +It lets you seamlessly connect computers anywhere in the world. Nebula is portable, and runs on Linux, OSX, and Windows. +(Also: keep this quiet, but we have an early prototype running on iOS). +It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers. + +Nebula incorporates a number of existing concepts like encryption, security groups, certificates, +and tunneling, and each of those individual pieces existed before Nebula in various forms. +What makes Nebula different to existing offerings is that it brings all of these ideas together, +resulting in a sum that is greater than its individual parts. + +You can read more about Nebula [here](https://medium.com/p/884110a5579). + +## Technical Overview + +Nebula is a mutually authenticated peer-to-peer software defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/). +Nebula uses certificates to assert a node's IP address, name, and membership within user-defined groups. +Nebula's user-defined groups allow for provider agnostic traffic filtering between nodes. +Discovery nodes allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs. +Users can move data between nodes in any number of cloud service providers, datacenters, and endpoints, without needing to maintain a particular addressing scheme. + +Nebula uses elliptic curve Diffie-Hellman key exchange, and AES-256-GCM in its default configuration. + +Nebula was created to provide a mechanism for groups hosts to communicate securely, even across the internet, while enabling expressive firewall definitions similar in style to cloud security groups. + +## Getting started (quickly) + +To set up a Nebula network, you'll need: + +#### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) for your specific platform. Specifically you'll need `nebula-cert` and the specific nebula binary for each platform you use. + +#### 2. (Optional, but you really should..) At least one discovery node with a routable IP address, which we call a lighthouse. + +Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $5/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses. + + Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet. + + +#### 3. A Nebula certificate authority, which will be the root of trust for a particular Nebula network. + + ``` + ./nebula-cert ca -name "Myorganization, Inc" + ``` + This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption. + +#### 4. Nebula host keys and certificates generated from that certificate authority +This assumes you have three nodes, named lighthouse1, host1, host3. You can name the nodes any way you'd like, including FQDN. You'll also need to choose IP addresses and the associated subnet. In this example, we are creating a nebula network that will use 192.168.100.x/24 as its network range. This example also demonstrates nebula groups, which can later be used to define traffic rules in a nebula network. +``` +./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24" +./nebula-cert sign -name "laptop" -ip "192.168.100.2/24" -groups "laptop,home,ssh" +./nebula-cert sign -name "server1" -ip "192.168.100.9/24" -groups "servers" +./nebula-cert sign -name "host3" -ip "192.168.100.9/24" +``` + +#### 5. Configuration files for each host +Download a copy of the nebula [example configuration](https://github.com/slackhq/nebula/blob/master/examples/config.yaml). + +* On the lighthouse node, you'll need to ensure `am_lighthouse: true` is set. + +* On the individual hosts, ensure the lighthouse is defined properly in the `static_host_map` section, and is added to the lighthouse `hosts` section. + + +#### 6. Copy nebula credentials, configuration, and binaries to each host + +For each host, copy the nebula binary to the host, along with `config.yaml` from step 5, and the files `ca.crt`, `{host}.crt`, and `{host}.key` from step 2. + +**DO NOT COPY `ca.key` TO INDIVIDUAL NODES.** + +#### 7. Run nebula on each host +``` +./nebula -config /path/to/config.yaml +``` + +## Building Nebula from source + +Download go and clone this repo. Change to the nebula directory. + +To build nebula for all platforms: +`make all` + +To build nebula for a specific platform (ex, Windows): +`make bin-windows` + +See the [Makefile](Makefile) for more details on build targets + +## Credits + +Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang. + + + diff --git a/bits.go b/bits.go new file mode 100644 index 0000000..49cadc1 --- /dev/null +++ b/bits.go @@ -0,0 +1,157 @@ +package nebula + +import ( + "github.com/rcrowley/go-metrics" + "github.com/sirupsen/logrus" +) + +type Bits struct { + length uint64 + current uint64 + bits []bool + firstSeen bool + lostCounter metrics.Counter + dupeCounter metrics.Counter + outOfWindowCounter metrics.Counter +} + +func NewBits(bits uint64) *Bits { + return &Bits{ + length: bits, + bits: make([]bool, bits, bits), + current: 0, + lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil), + dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), + outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil), + } +} + +func (b *Bits) Check(i uint64) bool { + // If i is the next number, return true. + if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) { + return true + } + + // If i is within the window, check if it's been set already. The first window will fail this check + if i > b.current-b.length { + return !b.bits[i%b.length] + } + + // If i is within the first window + if i < b.length { + return !b.bits[i%b.length] + } + + // Not within the window + l.Debugf("rejected a packet (top) %d %d\n", b.current, i) + return false +} + +func (b *Bits) Update(i uint64) bool { + // If i is the next number, return true and update current. + if i == b.current+1 { + // Report missed packets, we can only understand what was missed after the first window has been gone through + if i > b.length && b.bits[i%b.length] == false { + b.lostCounter.Inc(1) + } + b.bits[i%b.length] = true + b.current = i + return true + } + + // If i packet is greater than current but less than the maximum length of our bitmap, + // flip everything in between to false and move ahead. + if i > b.current && i < b.current+b.length { + // In between current and i need to be zero'd to allow those packets to come in later + for n := b.current + 1; n < i; n++ { + b.bits[n%b.length] = false + } + + b.bits[i%b.length] = true + b.current = i + //l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current) + return true + } + + // If i is greater than the delta between current and the total length of our bitmap, + // just flip everything in the map and move ahead. + if i >= b.current+b.length { + // The current window loss will be accounted for later, only record the jump as loss up until then + lost := maxInt64(0, int64(i-b.current-b.length)) + //TODO: explain this + if b.current == 0 { + lost++ + } + + for n := range b.bits { + // Don't want to count the first window as a loss + //TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed + //if b.bits[n] == false { + // lost++ + //} + b.bits[n] = false + } + + b.lostCounter.Inc(lost) + + if l.Level >= logrus.DebugLevel { + l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}). + Debug("Receive window") + } + b.bits[i%b.length] = true + b.current = i + return true + } + + // Allow for the 0 packet to come in within the first window + if i == 0 && b.firstSeen == false && b.current < b.length { + b.firstSeen = true + b.bits[i%b.length] = true + return true + } + + // If i is within the window of current minus length (the total pat window size), + // allow it and flip to true but to NOT change current. We also have to account for the first window + if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current { + if b.current == i { + if l.Level >= logrus.DebugLevel { + l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}). + Debug("Receive window") + } + b.dupeCounter.Inc(1) + return false + } + + if b.bits[i%b.length] == true { + if l.Level >= logrus.DebugLevel { + l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}). + Debug("Receive window") + } + b.dupeCounter.Inc(1) + return false + } + + b.bits[i%b.length] = true + return true + + } + + // In all other cases, fail and don't change current. + b.outOfWindowCounter.Inc(1) + if l.Level >= logrus.DebugLevel { + l.WithField("accepted", false). + WithField("currentCounter", b.current). + WithField("incomingCounter", i). + WithField("reason", "nonsense"). + Debug("Receive window") + } + return false +} + +func maxInt64(a, b int64) int64 { + if a > b { + return a + } + + return b +} diff --git a/bits_test.go b/bits_test.go new file mode 100644 index 0000000..f918c82 --- /dev/null +++ b/bits_test.go @@ -0,0 +1,223 @@ +package nebula + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBits(t *testing.T) { + b := NewBits(10) + + // make sure it is the right size + assert.Len(t, b.bits, 10) + + // This is initialized to zero - receive one. This should work. + + assert.True(t, b.Check(1)) + u := b.Update(1) + assert.True(t, u) + assert.EqualValues(t, 1, b.current) + g := []bool{false, true, false, false, false, false, false, false, false, false} + assert.Equal(t, g, b.bits) + + // Receive two + assert.True(t, b.Check(2)) + u = b.Update(2) + assert.True(t, u) + assert.EqualValues(t, 2, b.current) + g = []bool{false, true, true, false, false, false, false, false, false, false} + assert.Equal(t, g, b.bits) + + // Receive two again - it will fail + assert.False(t, b.Check(2)) + u = b.Update(2) + assert.False(t, u) + assert.EqualValues(t, 2, b.current) + + // Jump ahead to 15, which should clear everything and set the 6th element + assert.True(t, b.Check(15)) + u = b.Update(15) + assert.True(t, u) + assert.EqualValues(t, 15, b.current) + g = []bool{false, false, false, false, false, true, false, false, false, false} + assert.Equal(t, g, b.bits) + + // Mark 14, which is allowed because it is in the window + assert.True(t, b.Check(14)) + u = b.Update(14) + assert.True(t, u) + assert.EqualValues(t, 15, b.current) + g = []bool{false, false, false, false, true, true, false, false, false, false} + assert.Equal(t, g, b.bits) + + // Mark 5, which is not allowed because it is not in the window + assert.False(t, b.Check(5)) + u = b.Update(5) + assert.False(t, u) + assert.EqualValues(t, 15, b.current) + g = []bool{false, false, false, false, true, true, false, false, false, false} + assert.Equal(t, g, b.bits) + + // make sure we handle wrapping around once to the current position + b = NewBits(10) + assert.True(t, b.Update(1)) + assert.True(t, b.Update(11)) + assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits) + + // Walk through a few windows in order + b = NewBits(10) + for i := uint64(0); i <= 100; i++ { + assert.True(t, b.Check(i), "Error while checking %v", i) + assert.True(t, b.Update(i), "Error while updating %v", i) + } +} + +func TestBitsDupeCounter(t *testing.T) { + b := NewBits(10) + b.lostCounter.Clear() + b.dupeCounter.Clear() + b.outOfWindowCounter.Clear() + + assert.True(t, b.Update(1)) + assert.Equal(t, int64(0), b.dupeCounter.Count()) + + assert.False(t, b.Update(1)) + assert.Equal(t, int64(1), b.dupeCounter.Count()) + + assert.True(t, b.Update(2)) + assert.Equal(t, int64(1), b.dupeCounter.Count()) + + assert.True(t, b.Update(3)) + assert.Equal(t, int64(1), b.dupeCounter.Count()) + + assert.False(t, b.Update(1)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + assert.Equal(t, int64(2), b.dupeCounter.Count()) + assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) +} + +func TestBitsOutOfWindowCounter(t *testing.T) { + b := NewBits(10) + b.lostCounter.Clear() + b.dupeCounter.Clear() + b.outOfWindowCounter.Clear() + + assert.True(t, b.Update(20)) + assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) + + assert.True(t, b.Update(21)) + assert.True(t, b.Update(22)) + assert.True(t, b.Update(23)) + assert.True(t, b.Update(24)) + assert.True(t, b.Update(25)) + assert.True(t, b.Update(26)) + assert.True(t, b.Update(27)) + assert.True(t, b.Update(28)) + assert.True(t, b.Update(29)) + assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) + + assert.False(t, b.Update(0)) + assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) + + //tODO: make sure lostcounter doesn't increase in orderly increment + assert.Equal(t, int64(20), b.lostCounter.Count()) + assert.Equal(t, int64(0), b.dupeCounter.Count()) + assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) +} + +func TestBitsLostCounter(t *testing.T) { + b := NewBits(10) + b.lostCounter.Clear() + b.dupeCounter.Clear() + b.outOfWindowCounter.Clear() + + //assert.True(t, b.Update(0)) + assert.True(t, b.Update(0)) + assert.True(t, b.Update(20)) + assert.True(t, b.Update(21)) + assert.True(t, b.Update(22)) + assert.True(t, b.Update(23)) + assert.True(t, b.Update(24)) + assert.True(t, b.Update(25)) + assert.True(t, b.Update(26)) + assert.True(t, b.Update(27)) + assert.True(t, b.Update(28)) + assert.True(t, b.Update(29)) + assert.Equal(t, int64(20), b.lostCounter.Count()) + assert.Equal(t, int64(0), b.dupeCounter.Count()) + assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) + + b = NewBits(10) + b.lostCounter.Clear() + b.dupeCounter.Clear() + b.outOfWindowCounter.Clear() + + assert.True(t, b.Update(0)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + assert.True(t, b.Update(9)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + // 10 will set 0 index, 0 was already set, no lost packets + assert.True(t, b.Update(10)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + // 11 will set 1 index, 1 was missed, we should see 1 packet lost + assert.True(t, b.Update(11)) + assert.Equal(t, int64(1), b.lostCounter.Count()) + // Now let's fill in the window, should end up with 8 lost packets + assert.True(t, b.Update(12)) + assert.True(t, b.Update(13)) + assert.True(t, b.Update(14)) + assert.True(t, b.Update(15)) + assert.True(t, b.Update(16)) + assert.True(t, b.Update(17)) + assert.True(t, b.Update(18)) + assert.True(t, b.Update(19)) + assert.Equal(t, int64(8), b.lostCounter.Count()) + + // Jump ahead by a window size + assert.True(t, b.Update(29)) + assert.Equal(t, int64(8), b.lostCounter.Count()) + // Now lets walk ahead normally through the window, the missed packets should fill in + assert.True(t, b.Update(30)) + assert.True(t, b.Update(31)) + assert.True(t, b.Update(32)) + assert.True(t, b.Update(33)) + assert.True(t, b.Update(34)) + assert.True(t, b.Update(35)) + assert.True(t, b.Update(36)) + assert.True(t, b.Update(37)) + assert.True(t, b.Update(38)) + // 39 packets tracked, 22 seen, 17 lost + assert.Equal(t, int64(17), b.lostCounter.Count()) + + // Jump ahead by 2 windows, should have recording 1 full window missing + assert.True(t, b.Update(58)) + assert.Equal(t, int64(27), b.lostCounter.Count()) + // Now lets walk ahead normally through the window, the missed packets should fill in from this window + assert.True(t, b.Update(59)) + assert.True(t, b.Update(60)) + assert.True(t, b.Update(61)) + assert.True(t, b.Update(62)) + assert.True(t, b.Update(63)) + assert.True(t, b.Update(64)) + assert.True(t, b.Update(65)) + assert.True(t, b.Update(66)) + assert.True(t, b.Update(67)) + // 68 packets tracked, 32 seen, 36 missed + assert.Equal(t, int64(36), b.lostCounter.Count()) + assert.Equal(t, int64(0), b.dupeCounter.Count()) + assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) +} + +func BenchmarkBits(b *testing.B) { + z := NewBits(10) + for n := 0; n < b.N; n++ { + for i, _ := range z.bits { + z.bits[i] = true + } + for i, _ := range z.bits { + z.bits[i] = false + } + + } +} diff --git a/cert.go b/cert.go new file mode 100644 index 0000000..ff75922 --- /dev/null +++ b/cert.go @@ -0,0 +1,159 @@ +package nebula + +import ( + "errors" + "fmt" + "io/ioutil" + "strings" + "time" + + "github.com/slackhq/nebula/cert" +) + +var trustedCAs *cert.NebulaCAPool + +type CertState struct { + certificate *cert.NebulaCertificate + rawCertificate []byte + rawCertificateNoKey []byte + publicKey []byte + privateKey []byte +} + +func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { + // Marshal the certificate to ensure it is valid + rawCertificate, err := certificate.Marshal() + if err != nil { + return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) + } + + publicKey := certificate.Details.PublicKey + cs := &CertState{ + rawCertificate: rawCertificate, + certificate: certificate, // PublicKey has been set to nil above + privateKey: privateKey, + publicKey: publicKey, + } + + cs.certificate.Details.PublicKey = nil + rawCertNoKey, err := cs.certificate.Marshal() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate no key: %s", err) + } + cs.rawCertificateNoKey = rawCertNoKey + // put public key back + cs.certificate.Details.PublicKey = cs.publicKey + return cs, nil +} + +func NewCertStateFromConfig(c *Config) (*CertState, error) { + var pemPrivateKey []byte + var err error + + privPathOrPEM := c.GetString("pki.key", "") + if privPathOrPEM == "" { + // Support backwards compat with the old x509 + //TODO: remove after this is rolled out everywhere - NB 2018/02/23 + privPathOrPEM = c.GetString("x509.key", "") + } + + if privPathOrPEM == "" { + return nil, errors.New("no pki.key path or PEM data provided") + } + + if strings.Contains(privPathOrPEM, "-----BEGIN") { + pemPrivateKey = []byte(privPathOrPEM) + privPathOrPEM = "" + } else { + pemPrivateKey, err = ioutil.ReadFile(privPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) + } + } + + rawKey, _, err := cert.UnmarshalX25519PrivateKey(pemPrivateKey) + if err != nil { + return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + + var rawCert []byte + + pubPathOrPEM := c.GetString("pki.cert", "") + if pubPathOrPEM == "" { + // Support backwards compat with the old x509 + //TODO: remove after this is rolled out everywhere - NB 2018/02/23 + pubPathOrPEM = c.GetString("x509.cert", "") + } + + if pubPathOrPEM == "" { + return nil, errors.New("no pki.cert path or PEM data provided") + } + + if strings.Contains(pubPathOrPEM, "-----BEGIN") { + rawCert = []byte(pubPathOrPEM) + pubPathOrPEM = "" + } else { + rawCert, err = ioutil.ReadFile(pubPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err) + } + } + + nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) + if err != nil { + return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) + } + + if nebulaCert.Expired(time.Now()) { + return nil, fmt.Errorf("nebula certificate for this host is expired") + } + + if len(nebulaCert.Details.Ips) == 0 { + return nil, fmt.Errorf("no IPs encoded in certificate") + } + + if err = nebulaCert.VerifyPrivateKey(rawKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + + return NewCertState(nebulaCert, rawKey) +} + +func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) { + var rawCA []byte + var err error + + caPathOrPEM := c.GetString("pki.ca", "") + if caPathOrPEM == "" { + // Support backwards compat with the old x509 + //TODO: remove after this is rolled out everywhere - NB 2018/02/23 + caPathOrPEM = c.GetString("x509.ca", "") + } + + if caPathOrPEM == "" { + return nil, errors.New("no pki.ca path or PEM data provided") + } + + if strings.Contains(caPathOrPEM, "-----BEGIN") { + rawCA = []byte(caPathOrPEM) + caPathOrPEM = "" + } else { + rawCA, err = ioutil.ReadFile(caPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) + } + } + + CAs, err := cert.NewCAPoolFromBytes(rawCA) + if err != nil { + return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) + } + + // pki.blacklist entered the scene at about the same time we aliased x509 to pki, not supporting backwards compat + for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) { + l.WithField("fingerprint", fp).Infof("Blacklisting cert") + CAs.BlacklistFingerprint(fp) + } + + return CAs, nil +} diff --git a/cert/Makefile b/cert/Makefile new file mode 100644 index 0000000..c1c6448 --- /dev/null +++ b/cert/Makefile @@ -0,0 +1,9 @@ +GO111MODULE = on +export GO111MODULE + +cert.pb.go: cert.proto .FORCE + go build github.com/golang/protobuf/protoc-gen-go + PATH="$(PWD):$(PATH)" protoc --go_out=. $< + rm protoc-gen-go + +.FORCE: diff --git a/cert/README.md b/cert/README.md new file mode 100644 index 0000000..ae19a28 --- /dev/null +++ b/cert/README.md @@ -0,0 +1,15 @@ +## `cert` + +This is a library for interacting with `nebula` style certificates and authorities. + +A `protobuf` definition of the certificate format is also included + +### Compiling the protobuf definition + +Make sure you have `protoc` installed. + +To compile for `go` with the same version of protobuf specified in go.mod: + +```bash +make +``` diff --git a/cert/ca.go b/cert/ca.go new file mode 100644 index 0000000..43a47a2 --- /dev/null +++ b/cert/ca.go @@ -0,0 +1,120 @@ +package cert + +import ( + "fmt" + "strings" + "time" +) + +type NebulaCAPool struct { + CAs map[string]*NebulaCertificate + certBlacklist map[string]struct{} +} + +// NewCAPool creates a CAPool +func NewCAPool() *NebulaCAPool { + ca := NebulaCAPool{ + CAs: make(map[string]*NebulaCertificate), + certBlacklist: make(map[string]struct{}), + } + + return &ca +} + +func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) { + pool := NewCAPool() + var err error + for { + caPEMs, err = pool.AddCACertificate(caPEMs) + if err != nil { + return nil, err + } + if caPEMs == nil || len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { + break + } + } + + return pool, nil +} + +// AddCACertificate verifies a Nebula CA certificate and adds it to the pool +// Only the first pem encoded object will be consumed, any remaining bytes are returned. +// Parsed certificates will be verified and must be a CA +func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) { + c, pemBytes, err := UnmarshalNebulaCertificateFromPEM(pemBytes) + if err != nil { + return pemBytes, err + } + + if !c.Details.IsCA { + return pemBytes, fmt.Errorf("provided certificate was not a CA; %s", c.Details.Name) + } + + if !c.CheckSignature(c.Details.PublicKey) { + return pemBytes, fmt.Errorf("provided certificate was not self signed; %s", c.Details.Name) + } + + if c.Expired(time.Now()) { + return pemBytes, fmt.Errorf("provided CA certificate is expired; %s", c.Details.Name) + } + + sum, err := c.Sha256Sum() + if err != nil { + return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name) + } + + ncp.CAs[sum] = c + return pemBytes, nil +} + +// BlacklistFingerprint adds a cert fingerprint to the blacklist +func (ncp *NebulaCAPool) BlacklistFingerprint(f string) { + ncp.certBlacklist[f] = struct{}{} +} + +// ResetCertBlacklist removes all previously blacklisted cert fingerprints +func (ncp *NebulaCAPool) ResetCertBlacklist() { + ncp.certBlacklist = make(map[string]struct{}) +} + +// IsBlacklisted returns true if the fingerprint fails to generate or has been explicitly blacklisted +func (ncp *NebulaCAPool) IsBlacklisted(c *NebulaCertificate) bool { + h, err := c.Sha256Sum() + if err != nil { + return true + } + + if _, ok := ncp.certBlacklist[h]; ok { + return true + } + + return false +} + +// GetCAForCert attempts to return the signing certificate for the provided certificate. +// No signature validation is performed +func (ncp *NebulaCAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) { + if c.Details.Issuer == "" { + return nil, fmt.Errorf("no issuer in certificate") + } + + signer, ok := ncp.CAs[c.Details.Issuer] + if ok { + return signer, nil + } + + return nil, fmt.Errorf("could not find ca for the certificate") +} + +// GetFingerprints returns an array of trusted CA fingerprints +func (ncp *NebulaCAPool) GetFingerprints() []string { + fp := make([]string, len(ncp.CAs)) + + i := 0 + for k := range ncp.CAs { + fp[i] = k + i++ + } + + return fp +} diff --git a/cert/cert.go b/cert/cert.go new file mode 100644 index 0000000..492f183 --- /dev/null +++ b/cert/cert.go @@ -0,0 +1,445 @@ +package cert + +import ( + "crypto" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "encoding/pem" + "fmt" + "net" + "time" + + "bytes" + "encoding/json" + "github.com/golang/protobuf/proto" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" +) + +const publicKeyLen = 32 + +const ( + CertBanner = "NEBULA CERTIFICATE" + X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" + X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" + Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" + Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" +) + +type NebulaCertificate struct { + Details NebulaCertificateDetails + Signature []byte +} + +type NebulaCertificateDetails struct { + Name string + Ips []*net.IPNet + Subnets []*net.IPNet + Groups []string + NotBefore time.Time + NotAfter time.Time + PublicKey []byte + IsCA bool + Issuer string + + // Map of groups for faster lookup + InvertedGroups map[string]struct{} +} + +type m map[string]interface{} + +// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert +func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) { + if len(b) == 0 { + return nil, fmt.Errorf("nil byte array") + } + var rc RawNebulaCertificate + err := proto.Unmarshal(b, &rc) + if err != nil { + return nil, err + } + + if len(rc.Details.Ips)%2 != 0 { + return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found") + } + + if len(rc.Details.Subnets)%2 != 0 { + return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found") + } + + nc := NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: rc.Details.Name, + Groups: make([]string, len(rc.Details.Groups)), + Ips: make([]*net.IPNet, len(rc.Details.Ips)/2), + Subnets: make([]*net.IPNet, len(rc.Details.Subnets)/2), + NotBefore: time.Unix(rc.Details.NotBefore, 0), + NotAfter: time.Unix(rc.Details.NotAfter, 0), + PublicKey: make([]byte, len(rc.Details.PublicKey)), + IsCA: rc.Details.IsCA, + InvertedGroups: make(map[string]struct{}), + }, + Signature: make([]byte, len(rc.Signature)), + } + + copy(nc.Signature, rc.Signature) + copy(nc.Details.Groups, rc.Details.Groups) + nc.Details.Issuer = hex.EncodeToString(rc.Details.Issuer) + + if len(rc.Details.PublicKey) < publicKeyLen { + return nil, fmt.Errorf("Public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) + } + copy(nc.Details.PublicKey, rc.Details.PublicKey) + + for i, rawIp := range rc.Details.Ips { + if i%2 == 0 { + nc.Details.Ips[i/2] = &net.IPNet{IP: int2ip(rawIp)} + } else { + nc.Details.Ips[i/2].Mask = net.IPMask(int2ip(rawIp)) + } + } + + for i, rawIp := range rc.Details.Subnets { + if i%2 == 0 { + nc.Details.Subnets[i/2] = &net.IPNet{IP: int2ip(rawIp)} + } else { + nc.Details.Subnets[i/2].Mask = net.IPMask(int2ip(rawIp)) + } + } + + for _, g := range rc.Details.Groups { + nc.Details.InvertedGroups[g] = struct{}{} + } + + return &nc, nil +} + +// UnmarshalNebulaCertificateFromPEM will unmarshal the first pem block in a byte array, returning any non consumed data +// or an error on failure +func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) { + p, r := pem.Decode(b) + if p == nil { + return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") + } + nc, err := UnmarshalNebulaCertificate(p.Bytes) + return nc, r, err +} + +// MarshalX25519PrivateKey is a simple helper to PEM encode an X25519 private key +func MarshalX25519PrivateKey(b []byte) []byte { + return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) +} + +// MarshalEd25519PrivateKey is a simple helper to PEM encode an Ed25519 private key +func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte { + return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key}) +} + +// UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b +// or an error on failure +func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") + } + if k.Type != X25519PrivateKeyBanner { + return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 private key banner") + } + if len(k.Bytes) != publicKeyLen { + return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 private key") + } + + return k.Bytes, r, nil +} + +// UnmarshalEd25519PrivateKey will try to pem decode an Ed25519 private key, returning any other bytes b +// or an error on failure +func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") + } + if k.Type != Ed25519PrivateKeyBanner { + return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner") + } + if len(k.Bytes) != ed25519.PrivateKeySize { + return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") + } + + return k.Bytes, r, nil +} + +// MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key +func MarshalX25519PublicKey(b []byte) []byte { + return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) +} + +// MarshalEd25519PublicKey is a simple helper to PEM encode an Ed25519 public key +func MarshalEd25519PublicKey(key ed25519.PublicKey) []byte { + return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: key}) +} + +// UnmarshalX25519PublicKey will try to pem decode an X25519 public key, returning any other bytes b +// or an error on failure +func UnmarshalX25519PublicKey(b []byte) ([]byte, []byte, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") + } + if k.Type != X25519PublicKeyBanner { + return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 public key banner") + } + if len(k.Bytes) != publicKeyLen { + return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 public key") + } + + return k.Bytes, r, nil +} + +// UnmarshalEd25519PublicKey will try to pem decode an Ed25519 public key, returning any other bytes b +// or an error on failure +func UnmarshalEd25519PublicKey(b []byte) (ed25519.PublicKey, []byte, error) { + k, r := pem.Decode(b) + if k == nil { + return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") + } + if k.Type != Ed25519PublicKeyBanner { + return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 public key banner") + } + if len(k.Bytes) != ed25519.PublicKeySize { + return nil, r, fmt.Errorf("key was not 32 bytes, is invalid ed25519 public key") + } + + return k.Bytes, r, nil +} + +// Sign signs a nebula cert with the provided private key +func (nc *NebulaCertificate) Sign(key ed25519.PrivateKey) error { + b, err := proto.Marshal(nc.getRawDetails()) + if err != nil { + return err + } + + sig, err := key.Sign(rand.Reader, b, crypto.Hash(0)) + if err != nil { + return err + } + nc.Signature = sig + return nil +} + +// CheckSignature verifies the signature against the provided public key +func (nc *NebulaCertificate) CheckSignature(key ed25519.PublicKey) bool { + b, err := proto.Marshal(nc.getRawDetails()) + if err != nil { + return false + } + return ed25519.Verify(key, b, nc.Signature) +} + +// Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false +func (nc *NebulaCertificate) Expired(t time.Time) bool { + return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t) +} + +// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blacklist, etc) +func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) { + if ncp.IsBlacklisted(nc) { + return false, fmt.Errorf("certificate has been blacklisted") + } + + signer, err := ncp.GetCAForCert(nc) + if err != nil { + return false, err + } + + if signer.Expired(t) { + return false, fmt.Errorf("root certificate is expired") + } + + if nc.Expired(t) { + return false, fmt.Errorf("certificate is expired") + } + + if !nc.CheckSignature(signer.Details.PublicKey) { + return false, fmt.Errorf("certificate signature did not match") + } + + // If the signer has a limited set of groups make sure the cert only contains a subset + if len(signer.Details.InvertedGroups) > 0 { + for _, g := range nc.Details.Groups { + if _, ok := signer.Details.InvertedGroups[g]; !ok { + return false, fmt.Errorf("certificate contained a group not present on the signing ca; %s", g) + } + } + } + + return true, nil +} + +// VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match +func (nc *NebulaCertificate) VerifyPrivateKey(key []byte) error { + var dst, key32 [32]byte + copy(key32[:], key) + curve25519.ScalarBaseMult(&dst, &key32) + if !bytes.Equal(dst[:], nc.Details.PublicKey) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + return nil +} + +// String will return a pretty printed representation of a nebula cert +func (nc *NebulaCertificate) String() string { + if nc == nil { + return "NebulaCertificate {}\n" + } + + s := "NebulaCertificate {\n" + s += "\tDetails {\n" + s += fmt.Sprintf("\t\tName: %v\n", nc.Details.Name) + + if len(nc.Details.Ips) > 0 { + s += "\t\tIps: [\n" + for _, ip := range nc.Details.Ips { + s += fmt.Sprintf("\t\t\t%v\n", ip.String()) + } + s += "\t\t]\n" + } else { + s += "\t\tIps: []\n" + } + + if len(nc.Details.Subnets) > 0 { + s += "\t\tSubnets: [\n" + for _, ip := range nc.Details.Subnets { + s += fmt.Sprintf("\t\t\t%v\n", ip.String()) + } + s += "\t\t]\n" + } else { + s += "\t\tSubnets: []\n" + } + + if len(nc.Details.Groups) > 0 { + s += "\t\tGroups: [\n" + for _, g := range nc.Details.Groups { + s += fmt.Sprintf("\t\t\t\"%v\"\n", g) + } + s += "\t\t]\n" + } else { + s += "\t\tGroups: []\n" + } + + s += fmt.Sprintf("\t\tNot before: %v\n", nc.Details.NotBefore) + s += fmt.Sprintf("\t\tNot After: %v\n", nc.Details.NotAfter) + s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA) + s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer) + s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey) + s += "\t}\n" + fp, err := nc.Sha256Sum() + if err == nil { + s += fmt.Sprintf("\tFingerprint: %s\n", fp) + } + s += fmt.Sprintf("\tSignature: %x\n", nc.Signature) + s += "}" + + return s +} + +// getRawDetails marshals the raw details into protobuf ready struct +func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails { + rd := &RawNebulaCertificateDetails{ + Name: nc.Details.Name, + Groups: nc.Details.Groups, + NotBefore: nc.Details.NotBefore.Unix(), + NotAfter: nc.Details.NotAfter.Unix(), + PublicKey: make([]byte, len(nc.Details.PublicKey)), + IsCA: nc.Details.IsCA, + } + + for _, ipNet := range nc.Details.Ips { + rd.Ips = append(rd.Ips, ip2int(ipNet.IP), ip2int(ipNet.Mask)) + } + + for _, ipNet := range nc.Details.Subnets { + rd.Subnets = append(rd.Subnets, ip2int(ipNet.IP), ip2int(ipNet.Mask)) + } + + copy(rd.PublicKey, nc.Details.PublicKey[:]) + + // I know, this is terrible + rd.Issuer, _ = hex.DecodeString(nc.Details.Issuer) + + return rd +} + +// Marshal will marshal a nebula cert into a protobuf byte array +func (nc *NebulaCertificate) Marshal() ([]byte, error) { + rc := RawNebulaCertificate{ + Details: nc.getRawDetails(), + Signature: nc.Signature, + } + + return proto.Marshal(&rc) +} + +// MarshalToPEM will marshal a nebula cert into a protobuf byte array and pem encode the result +func (nc *NebulaCertificate) MarshalToPEM() ([]byte, error) { + b, err := nc.Marshal() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: CertBanner, Bytes: b}), nil +} + +// Sha256Sum calculates a sha-256 sum of the marshaled certificate +func (nc *NebulaCertificate) Sha256Sum() (string, error) { + b, err := nc.Marshal() + if err != nil { + return "", err + } + + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]), nil +} + +func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) { + toString := func(ips []*net.IPNet) []string { + s := []string{} + for _, ip := range ips { + s = append(s, ip.String()) + } + return s + } + + fp, _ := nc.Sha256Sum() + jc := m{ + "details": m{ + "name": nc.Details.Name, + "ips": toString(nc.Details.Ips), + "subnets": toString(nc.Details.Subnets), + "groups": nc.Details.Groups, + "notBefore": nc.Details.NotBefore, + "notAfter": nc.Details.NotAfter, + "publicKey": fmt.Sprintf("%x", nc.Details.PublicKey), + "isCa": nc.Details.IsCA, + "issuer": nc.Details.Issuer, + }, + "fingerprint": fp, + "signature": fmt.Sprintf("%x", nc.Signature), + } + return json.Marshal(jc) +} + +func ip2int(ip []byte) uint32 { + if len(ip) == 16 { + return binary.BigEndian.Uint32(ip[12:16]) + } + return binary.BigEndian.Uint32(ip) +} + +func int2ip(nn uint32) net.IP { + ip := make(net.IP, net.IPv4len) + binary.BigEndian.PutUint32(ip, nn) + return ip +} diff --git a/cert/cert.pb.go b/cert/cert.pb.go new file mode 100644 index 0000000..02c0826 --- /dev/null +++ b/cert/cert.pb.go @@ -0,0 +1,202 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: cert.proto + +package cert + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type RawNebulaCertificate struct { + Details *RawNebulaCertificateDetails `protobuf:"bytes,1,opt,name=Details,json=details,proto3" json:"Details,omitempty"` + Signature []byte `protobuf:"bytes,2,opt,name=Signature,json=signature,proto3" json:"Signature,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *RawNebulaCertificate) Reset() { *m = RawNebulaCertificate{} } +func (m *RawNebulaCertificate) String() string { return proto.CompactTextString(m) } +func (*RawNebulaCertificate) ProtoMessage() {} +func (*RawNebulaCertificate) Descriptor() ([]byte, []int) { + return fileDescriptor_a142e29cbef9b1cf, []int{0} +} + +func (m *RawNebulaCertificate) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_RawNebulaCertificate.Unmarshal(m, b) +} +func (m *RawNebulaCertificate) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_RawNebulaCertificate.Marshal(b, m, deterministic) +} +func (m *RawNebulaCertificate) XXX_Merge(src proto.Message) { + xxx_messageInfo_RawNebulaCertificate.Merge(m, src) +} +func (m *RawNebulaCertificate) XXX_Size() int { + return xxx_messageInfo_RawNebulaCertificate.Size(m) +} +func (m *RawNebulaCertificate) XXX_DiscardUnknown() { + xxx_messageInfo_RawNebulaCertificate.DiscardUnknown(m) +} + +var xxx_messageInfo_RawNebulaCertificate proto.InternalMessageInfo + +func (m *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails { + if m != nil { + return m.Details + } + return nil +} + +func (m *RawNebulaCertificate) GetSignature() []byte { + if m != nil { + return m.Signature + } + return nil +} + +type RawNebulaCertificateDetails struct { + Name string `protobuf:"bytes,1,opt,name=Name,json=name,proto3" json:"Name,omitempty"` + // Ips and Subnets are in big endian 32 bit pairs, 1st the ip, 2nd the mask + Ips []uint32 `protobuf:"varint,2,rep,packed,name=Ips,json=ips,proto3" json:"Ips,omitempty"` + Subnets []uint32 `protobuf:"varint,3,rep,packed,name=Subnets,json=subnets,proto3" json:"Subnets,omitempty"` + Groups []string `protobuf:"bytes,4,rep,name=Groups,json=groups,proto3" json:"Groups,omitempty"` + NotBefore int64 `protobuf:"varint,5,opt,name=NotBefore,json=notBefore,proto3" json:"NotBefore,omitempty"` + NotAfter int64 `protobuf:"varint,6,opt,name=NotAfter,json=notAfter,proto3" json:"NotAfter,omitempty"` + PublicKey []byte `protobuf:"bytes,7,opt,name=PublicKey,json=publicKey,proto3" json:"PublicKey,omitempty"` + IsCA bool `protobuf:"varint,8,opt,name=IsCA,json=isCA,proto3" json:"IsCA,omitempty"` + // sha-256 of the issuer certificate, if this field is blank the cert is self-signed + Issuer []byte `protobuf:"bytes,9,opt,name=Issuer,json=issuer,proto3" json:"Issuer,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *RawNebulaCertificateDetails) Reset() { *m = RawNebulaCertificateDetails{} } +func (m *RawNebulaCertificateDetails) String() string { return proto.CompactTextString(m) } +func (*RawNebulaCertificateDetails) ProtoMessage() {} +func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) { + return fileDescriptor_a142e29cbef9b1cf, []int{1} +} + +func (m *RawNebulaCertificateDetails) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_RawNebulaCertificateDetails.Unmarshal(m, b) +} +func (m *RawNebulaCertificateDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_RawNebulaCertificateDetails.Marshal(b, m, deterministic) +} +func (m *RawNebulaCertificateDetails) XXX_Merge(src proto.Message) { + xxx_messageInfo_RawNebulaCertificateDetails.Merge(m, src) +} +func (m *RawNebulaCertificateDetails) XXX_Size() int { + return xxx_messageInfo_RawNebulaCertificateDetails.Size(m) +} +func (m *RawNebulaCertificateDetails) XXX_DiscardUnknown() { + xxx_messageInfo_RawNebulaCertificateDetails.DiscardUnknown(m) +} + +var xxx_messageInfo_RawNebulaCertificateDetails proto.InternalMessageInfo + +func (m *RawNebulaCertificateDetails) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *RawNebulaCertificateDetails) GetIps() []uint32 { + if m != nil { + return m.Ips + } + return nil +} + +func (m *RawNebulaCertificateDetails) GetSubnets() []uint32 { + if m != nil { + return m.Subnets + } + return nil +} + +func (m *RawNebulaCertificateDetails) GetGroups() []string { + if m != nil { + return m.Groups + } + return nil +} + +func (m *RawNebulaCertificateDetails) GetNotBefore() int64 { + if m != nil { + return m.NotBefore + } + return 0 +} + +func (m *RawNebulaCertificateDetails) GetNotAfter() int64 { + if m != nil { + return m.NotAfter + } + return 0 +} + +func (m *RawNebulaCertificateDetails) GetPublicKey() []byte { + if m != nil { + return m.PublicKey + } + return nil +} + +func (m *RawNebulaCertificateDetails) GetIsCA() bool { + if m != nil { + return m.IsCA + } + return false +} + +func (m *RawNebulaCertificateDetails) GetIssuer() []byte { + if m != nil { + return m.Issuer + } + return nil +} + +func init() { + proto.RegisterType((*RawNebulaCertificate)(nil), "cert.RawNebulaCertificate") + proto.RegisterType((*RawNebulaCertificateDetails)(nil), "cert.RawNebulaCertificateDetails") +} + +func init() { proto.RegisterFile("cert.proto", fileDescriptor_a142e29cbef9b1cf) } + +var fileDescriptor_a142e29cbef9b1cf = []byte{ + // 279 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x90, 0xcf, 0x4a, 0xf4, 0x30, + 0x14, 0xc5, 0xc9, 0xa4, 0x5f, 0xdb, 0xe4, 0x53, 0x90, 0x20, 0x12, 0xd4, 0x45, 0x9c, 0x55, 0x56, + 0xb3, 0xd0, 0xa5, 0xab, 0x71, 0x04, 0x29, 0x42, 0x91, 0xcc, 0x13, 0xa4, 0xf5, 0x76, 0x08, 0x74, + 0x9a, 0x9a, 0x3f, 0x88, 0x8f, 0xee, 0x4e, 0x9a, 0x4e, 0x77, 0xe2, 0xee, 0x9e, 0x5f, 0xce, 0x49, + 0x4e, 0x2e, 0xa5, 0x2d, 0xb8, 0xb0, 0x19, 0x9d, 0x0d, 0x96, 0x65, 0xd3, 0xbc, 0xfe, 0xa0, 0x97, + 0x4a, 0x7f, 0xd6, 0xd0, 0xc4, 0x5e, 0xef, 0xc0, 0x05, 0xd3, 0x99, 0x56, 0x07, 0x60, 0x8f, 0xb4, + 0x78, 0x86, 0xa0, 0x4d, 0xef, 0x39, 0x12, 0x48, 0xfe, 0xbf, 0xbf, 0xdb, 0xa4, 0xec, 0x6f, 0xe6, + 0x93, 0x51, 0x15, 0xef, 0xf3, 0xc0, 0x6e, 0x29, 0xd9, 0x9b, 0xc3, 0xa0, 0x43, 0x74, 0xc0, 0x57, + 0x02, 0xc9, 0x33, 0x45, 0xfc, 0x02, 0xd6, 0xdf, 0x88, 0xde, 0xfc, 0x71, 0x0d, 0x63, 0x34, 0xab, + 0xf5, 0x11, 0xd2, 0xbb, 0x44, 0x65, 0x83, 0x3e, 0x02, 0xbb, 0xa0, 0xb8, 0x1a, 0x3d, 0x5f, 0x09, + 0x2c, 0xcf, 0x15, 0x36, 0xa3, 0x67, 0x9c, 0x16, 0xfb, 0xd8, 0x0c, 0x10, 0x3c, 0xc7, 0x89, 0x16, + 0x7e, 0x96, 0xec, 0x8a, 0xe6, 0x2f, 0xce, 0xc6, 0xd1, 0xf3, 0x4c, 0x60, 0x49, 0x54, 0x7e, 0x48, + 0x6a, 0x6a, 0x55, 0xdb, 0xf0, 0x04, 0x9d, 0x75, 0xc0, 0xff, 0x09, 0x24, 0xb1, 0x22, 0xc3, 0x02, + 0xd8, 0x35, 0x2d, 0x6b, 0x1b, 0xb6, 0x5d, 0x00, 0xc7, 0xf3, 0x74, 0x58, 0x0e, 0x27, 0x3d, 0x25, + 0xdf, 0x62, 0xd3, 0x9b, 0xf6, 0x15, 0xbe, 0x78, 0x31, 0xff, 0x67, 0x5c, 0xc0, 0xd4, 0xb7, 0xf2, + 0xbb, 0x2d, 0x2f, 0x05, 0x92, 0xa5, 0xca, 0x8c, 0xdf, 0x6d, 0xa7, 0x0e, 0x95, 0xf7, 0x11, 0x1c, + 0x27, 0xc9, 0x9e, 0x9b, 0xa4, 0x9a, 0x3c, 0xed, 0xfe, 0xe1, 0x27, 0x00, 0x00, 0xff, 0xff, 0x2c, + 0xe3, 0x08, 0x37, 0x89, 0x01, 0x00, 0x00, +} diff --git a/cert/cert.proto b/cert/cert.proto new file mode 100644 index 0000000..29977c4 --- /dev/null +++ b/cert/cert.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; +package cert; + +//import "google/protobuf/timestamp.proto"; + +message RawNebulaCertificate { + RawNebulaCertificateDetails Details = 1; + bytes Signature = 2; +} + +message RawNebulaCertificateDetails { + string Name = 1; + + // Ips and Subnets are in big endian 32 bit pairs, 1st the ip, 2nd the mask + repeated uint32 Ips = 2; + repeated uint32 Subnets = 3; + + repeated string Groups = 4; + int64 NotBefore = 5; + int64 NotAfter = 6; + bytes PublicKey = 7; + + bool IsCA = 8; + + // sha-256 of the issuer certificate, if this field is blank the cert is self-signed + bytes Issuer = 9; +} \ No newline at end of file diff --git a/cert/cert_test.go b/cert/cert_test.go new file mode 100644 index 0000000..3f99b41 --- /dev/null +++ b/cert/cert_test.go @@ -0,0 +1,373 @@ +package cert + +import ( + "crypto/rand" + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" +) + +func TestMarshalingNebulaCertificate(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: "testing", + Ips: []*net.IPNet{ + {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + }, + Subnets: []*net.IPNet{ + {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + Issuer: "1234567890abcedfghij1234567890ab", + }, + Signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.Marshal() + assert.Nil(t, err) + t.Log("Cert size:", len(b)) + + nc2, err := UnmarshalNebulaCertificate(b) + assert.Nil(t, err) + + assert.Equal(t, nc.Signature, nc2.Signature) + assert.Equal(t, nc.Details.Name, nc2.Details.Name) + assert.Equal(t, nc.Details.NotBefore, nc2.Details.NotBefore) + assert.Equal(t, nc.Details.NotAfter, nc2.Details.NotAfter) + assert.Equal(t, nc.Details.PublicKey, nc2.Details.PublicKey) + assert.Equal(t, nc.Details.IsCA, nc2.Details.IsCA) + + // IP byte arrays can be 4 or 16 in length so we have to go this route + assert.Equal(t, len(nc.Details.Ips), len(nc2.Details.Ips)) + for i, wIp := range nc.Details.Ips { + assert.Equal(t, wIp.String(), nc2.Details.Ips[i].String()) + } + + assert.Equal(t, len(nc.Details.Subnets), len(nc2.Details.Subnets)) + for i, wIp := range nc.Details.Subnets { + assert.Equal(t, wIp.String(), nc2.Details.Subnets[i].String()) + } + + assert.EqualValues(t, nc.Details.Groups, nc2.Details.Groups) +} + +func TestNebulaCertificate_Sign(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: "testing", + Ips: []*net.IPNet{ + {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + }, + Subnets: []*net.IPNet{ + {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + Issuer: "1234567890abcedfghij1234567890ab", + }, + } + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + assert.Nil(t, err) + assert.False(t, nc.CheckSignature(pub)) + assert.Nil(t, nc.Sign(priv)) + assert.True(t, nc.CheckSignature(pub)) + + b, err := nc.Marshal() + assert.Nil(t, err) + t.Log("Cert size:", len(b)) +} + +func TestNebulaCertificate_Expired(t *testing.T) { + nc := NebulaCertificate{ + Details: NebulaCertificateDetails{ + NotBefore: time.Now().Add(time.Second * -60).Round(time.Second), + NotAfter: time.Now().Add(time.Second * 60).Round(time.Second), + }, + } + + assert.True(t, nc.Expired(time.Now().Add(time.Hour))) + assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) + assert.False(t, nc.Expired(time.Now())) +} + +func TestNebulaCertificate_MarshalJSON(t *testing.T) { + time.Local = time.UTC + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: "testing", + Ips: []*net.IPNet{ + {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + }, + Subnets: []*net.IPNet{ + {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), + NotAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), + PublicKey: pubKey, + IsCA: false, + Issuer: "1234567890abcedfghij1234567890ab", + }, + Signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.MarshalJSON() + assert.Nil(t, err) + assert.Equal( + t, + "{\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", + string(b), + ) +} + +func TestNebulaCertificate_Verify(t *testing.T) { + ca, _, caKey, err := newTestCaCert() + assert.Nil(t, err) + + c, _, _, err := newTestCert(ca, caKey) + assert.Nil(t, err) + + h, err := ca.Sha256Sum() + assert.Nil(t, err) + + caPool := NewCAPool() + caPool.CAs[h] = ca + + f, err := c.Sha256Sum() + assert.Nil(t, err) + caPool.BlacklistFingerprint(f) + + v, err := c.Verify(time.Now(), caPool) + assert.False(t, v) + assert.EqualError(t, err, "certificate has been blacklisted") + + caPool.ResetCertBlacklist() + v, err = c.Verify(time.Now(), caPool) + assert.True(t, v) + assert.Nil(t, err) + + v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool) + assert.False(t, v) + assert.EqualError(t, err, "root certificate is expired") +} + +func TestNebulaVerifyPrivateKey(t *testing.T) { + ca, _, caKey, err := newTestCaCert() + assert.Nil(t, err) + + c, _, priv, err := newTestCert(ca, caKey) + err = c.VerifyPrivateKey(priv) + assert.Nil(t, err) + + _, priv2 := x25519Keypair() + err = c.VerifyPrivateKey(priv2) + assert.NotNil(t, err) +} + +func TestNewCAPoolFromBytes(t *testing.T) { + noNewLines := ` +# Current provisional, Remove once everything moves over to the real root. +-----BEGIN NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-----END NEBULA CERTIFICATE----- +# root-ca01 +-----BEGIN NEBULA CERTIFICATE----- +CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG +BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf +8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF +-----END NEBULA CERTIFICATE----- +` + + withNewLines := ` +# Current provisional, Remove once everything moves over to the real root. + +-----BEGIN NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL +vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv +bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +-----END NEBULA CERTIFICATE----- + +# root-ca01 + + +-----BEGIN NEBULA CERTIFICATE----- +CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG +BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf +8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF +-----END NEBULA CERTIFICATE----- + +` + + rootCA := NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: "nebula root ca", + }, + } + + rootCA01 := NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: "nebula root ca 01", + }, + } + + p, err := NewCAPoolFromBytes([]byte(noNewLines)) + assert.Nil(t, err) + assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) + assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) + + pp, err := NewCAPoolFromBytes([]byte(withNewLines)) + assert.Nil(t, err) + assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) + assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) +} + +// Ensure that upgrading the protobuf library does not change how certificates +// are marshalled, since this would break signature verification +func TestMarshalingNebulaCertificateConsistency(t *testing.T) { + before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: "testing", + Ips: []*net.IPNet{ + {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + }, + Subnets: []*net.IPNet{ + {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + Issuer: "1234567890abcedfghij1234567890ab", + }, + Signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.Marshal() + assert.Nil(t, err) + t.Log("Cert size:", len(b)) + assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) + + b, err = proto.Marshal(nc.getRawDetails()) + assert.Nil(t, err) + t.Log("Raw cert size:", len(b)) + assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) +} + +func newTestCaCert() (*NebulaCertificate, []byte, []byte, error) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + + nc := &NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: "test ca", + NotBefore: before, + NotAfter: after, + PublicKey: pub, + IsCA: true, + }, + } + + err = nc.Sign(priv) + if err != nil { + return nil, nil, nil, err + } + return nc, pub, priv, nil +} + +func newTestCert(ca *NebulaCertificate, key []byte) (*NebulaCertificate, []byte, []byte, error) { + issuer, err := ca.Sha256Sum() + if err != nil { + return nil, nil, nil, err + } + + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pub, rawPriv := x25519Keypair() + + nc := &NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: "testing", + Ips: []*net.IPNet{ + {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + }, + Subnets: []*net.IPNet{ + {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, + {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pub, + IsCA: false, + Issuer: issuer, + }, + } + + err = nc.Sign(key) + if err != nil { + return nil, nil, nil, err + } + + return nc, pub, rawPriv, nil +} + +func x25519Keypair() ([]byte, []byte) { + var pubkey, privkey [32]byte + if _, err := io.ReadFull(rand.Reader, privkey[:]); err != nil { + panic(err) + } + curve25519.ScalarBaseMult(&pubkey, &privkey) + return pubkey[:], privkey[:] +} diff --git a/cidr_radix.go b/cidr_radix.go new file mode 100644 index 0000000..7726b9a --- /dev/null +++ b/cidr_radix.go @@ -0,0 +1,147 @@ +package nebula + +import ( + "encoding/binary" + "fmt" + "net" +) + +type CIDRNode struct { + left *CIDRNode + right *CIDRNode + parent *CIDRNode + value interface{} +} + +type CIDRTree struct { + root *CIDRNode +} + +const ( + startbit = uint32(0x80000000) +) + +func NewCIDRTree() *CIDRTree { + tree := new(CIDRTree) + tree.root = &CIDRNode{} + return tree +} + +func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) { + bit := startbit + node := tree.root + next := tree.root + + ip := ip2int(cidr.IP) + mask := ip2int(cidr.Mask) + + // Find our last ancestor in the tree + for bit&mask != 0 { + if ip&bit != 0 { + next = node.right + } else { + next = node.left + } + + if next == nil { + break + } + + bit = bit >> 1 + node = next + } + + // We already have this range so update the value + if next != nil { + node.value = val + return + } + + // Build up the rest of the tree we don't already have + for bit&mask != 0 { + next = &CIDRNode{} + next.parent = node + + if ip&bit != 0 { + node.right = next + } else { + node.left = next + } + + bit >>= 1 + node = next + } + + // Final node marks our cidr, set the value + node.value = val +} + +// Finds the first match, which way be the least specific +func (tree *CIDRTree) Contains(ip uint32) (value interface{}) { + bit := startbit + node := tree.root + + for node != nil { + if node.value != nil { + return node.value + } + + if ip&bit != 0 { + node = node.right + } else { + node = node.left + } + + bit >>= 1 + + } + + return value +} + +// Finds the most specific match +func (tree *CIDRTree) Match(ip uint32) (value interface{}) { + bit := startbit + node := tree.root + lastNode := node + + for node != nil { + lastNode = node + if ip&bit != 0 { + node = node.right + } else { + node = node.left + } + + bit >>= 1 + } + + if bit == 0 && lastNode != nil { + value = lastNode.value + } + return value +} + +// A helper type to avoid converting to IP when logging +type IntIp uint32 + +func (ip IntIp) String() string { + return fmt.Sprintf("%v", int2ip(uint32(ip))) +} + +func (ip IntIp) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf("\"%s\"", int2ip(uint32(ip)).String())), nil +} + +func ip2int(ip []byte) uint32 { + if len(ip) == 16 { + return binary.BigEndian.Uint32(ip[12:16]) + } + return binary.BigEndian.Uint32(ip) +} + +func int2ip(nn uint32) net.IP { + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, nn) + return ip +} diff --git a/cidr_radix_test.go b/cidr_radix_test.go new file mode 100644 index 0000000..e7461bd --- /dev/null +++ b/cidr_radix_test.go @@ -0,0 +1,118 @@ +package nebula + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCIDRTree_Contains(t *testing.T) { + tree := NewCIDRTree() + tree.AddCIDR(getCIDR("1.0.0.0/8"), "1") + tree.AddCIDR(getCIDR("2.1.0.0/16"), "2") + tree.AddCIDR(getCIDR("3.1.1.0/24"), "3") + tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a") + tree.AddCIDR(getCIDR("4.1.1.1/32"), "4b") + tree.AddCIDR(getCIDR("4.1.2.1/32"), "4c") + tree.AddCIDR(getCIDR("254.0.0.0/4"), "5") + + tests := []struct { + Result interface{} + IP string + }{ + {"1", "1.0.0.0"}, + {"1", "1.255.255.255"}, + {"2", "2.1.0.0"}, + {"2", "2.1.255.255"}, + {"3", "3.1.1.0"}, + {"3", "3.1.1.255"}, + {"4a", "4.1.1.255"}, + {"4a", "4.1.1.1"}, + {"5", "240.0.0.0"}, + {"5", "255.255.255.255"}, + {nil, "239.0.0.0"}, + {nil, "4.1.2.2"}, + } + + for _, tt := range tests { + assert.Equal(t, tt.Result, tree.Contains(ip2int(net.ParseIP(tt.IP)))) + } + + tree = NewCIDRTree() + tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool") + assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0")))) + assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255")))) +} + +func TestCIDRTree_Match(t *testing.T) { + tree := NewCIDRTree() + tree.AddCIDR(getCIDR("4.1.1.0/32"), "1a") + tree.AddCIDR(getCIDR("4.1.1.1/32"), "1b") + + tests := []struct { + Result interface{} + IP string + }{ + {"1a", "4.1.1.0"}, + {"1b", "4.1.1.1"}, + } + + for _, tt := range tests { + assert.Equal(t, tt.Result, tree.Match(ip2int(net.ParseIP(tt.IP)))) + } + + tree = NewCIDRTree() + tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool") + assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0")))) + assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255")))) +} + +func BenchmarkCIDRTree_Contains(b *testing.B) { + tree := NewCIDRTree() + tree.AddCIDR(getCIDR("1.1.0.0/16"), "1") + tree.AddCIDR(getCIDR("1.2.1.1/32"), "1") + tree.AddCIDR(getCIDR("192.2.1.1/32"), "1") + tree.AddCIDR(getCIDR("172.2.1.1/32"), "1") + + ip := ip2int(net.ParseIP("1.2.1.1")) + b.Run("found", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tree.Contains(ip) + } + }) + + ip = ip2int(net.ParseIP("1.2.1.255")) + b.Run("not found", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tree.Contains(ip) + } + }) +} + +func BenchmarkCIDRTree_Match(b *testing.B) { + tree := NewCIDRTree() + tree.AddCIDR(getCIDR("1.1.0.0/16"), "1") + tree.AddCIDR(getCIDR("1.2.1.1/32"), "1") + tree.AddCIDR(getCIDR("192.2.1.1/32"), "1") + tree.AddCIDR(getCIDR("172.2.1.1/32"), "1") + + ip := ip2int(net.ParseIP("1.2.1.1")) + b.Run("found", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tree.Match(ip) + } + }) + + ip = ip2int(net.ParseIP("1.2.1.255")) + b.Run("not found", func(b *testing.B) { + for i := 0; i < b.N; i++ { + tree.Match(ip) + } + }) +} + +func getCIDR(s string) *net.IPNet { + _, c, _ := net.ParseCIDR(s) + return c +} diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go new file mode 100644 index 0000000..bb96455 --- /dev/null +++ b/cmd/nebula-cert/ca.go @@ -0,0 +1,124 @@ +package main + +import ( + "crypto/rand" + "flag" + "fmt" + "io" + "io/ioutil" + "os" + "strings" + "time" + + "golang.org/x/crypto/ed25519" + "github.com/slackhq/nebula/cert" +) + +type caFlags struct { + set *flag.FlagSet + name *string + duration *time.Duration + outKeyPath *string + outCertPath *string + groups *string +} + +func newCaFlags() *caFlags { + cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)} + cf.set.Usage = func() {} + cf.name = cf.set.String("name", "", "Required: name of the certificate authority") + cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") + cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to") + cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to") + cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") + return &cf +} + +func ca(args []string, out io.Writer, errOut io.Writer) error { + cf := newCaFlags() + err := cf.set.Parse(args) + if err != nil { + return err + } + + if err := mustFlagString("name", cf.name); err != nil { + return err + } + if err := mustFlagString("out-key", cf.outKeyPath); err != nil { + return err + } + if err := mustFlagString("out-crt", cf.outCertPath); err != nil { + return err + } + + if *cf.duration <= 0 { + return &helpError{"-duration must be greater than 0"} + } + + groups := []string{} + if *cf.groups != "" { + for _, rg := range strings.Split(*cf.groups, ",") { + g := strings.TrimSpace(rg) + if g != "" { + groups = append(groups, g) + } + } + } + + pub, rawPriv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return fmt.Errorf("error while generating ed25519 keys: %s", err) + } + + nc := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: *cf.name, + Groups: groups, + NotBefore: time.Now(), + NotAfter: time.Now().Add(*cf.duration), + PublicKey: pub, + IsCA: true, + }, + } + + if _, err := os.Stat(*cf.outKeyPath); err == nil { + return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath) + } + + if _, err := os.Stat(*cf.outCertPath); err == nil { + return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) + } + + err = nc.Sign(rawPriv) + if err != nil { + return fmt.Errorf("error while signing: %s", err) + } + + err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalEd25519PrivateKey(rawPriv), 0600) + if err != nil { + return fmt.Errorf("error while writing out-key: %s", err) + } + + b, err := nc.MarshalToPEM() + if err != nil { + return fmt.Errorf("error while marshalling certificate: %s", err) + } + + err = ioutil.WriteFile(*cf.outCertPath, b, 0600) + if err != nil { + return fmt.Errorf("error while writing out-crt: %s", err) + } + + return nil +} + +func caSummary() string { + return "ca : create a self signed certificate authority" +} + +func caHelp(out io.Writer) { + cf := newCaFlags() + out.Write([]byte("Usage of " + os.Args[0] + " " + caSummary() + "\n")) + cf.set.SetOutput(out) + cf.set.PrintDefaults() +} diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go new file mode 100644 index 0000000..4abd394 --- /dev/null +++ b/cmd/nebula-cert/ca_test.go @@ -0,0 +1,132 @@ +package main + +import ( + "bytes" + "io/ioutil" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/slackhq/nebula/cert" +) + +//TODO: test file permissions + +func Test_caSummary(t *testing.T) { + assert.Equal(t, "ca : create a self signed certificate authority", caSummary()) +} + +func Test_caHelp(t *testing.T) { + ob := &bytes.Buffer{} + caHelp(ob) + assert.Equal( + t, + "Usage of "+os.Args[0]+" ca : create a self signed certificate authority\n"+ + " -duration duration\n"+ + " \tOptional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\" (default 8760h0m0s)\n"+ + " -groups string\n"+ + " \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+ + " -name string\n"+ + " \tRequired: name of the certificate authority\n"+ + " -out-crt string\n"+ + " \tOptional: path to write the certificate to (default \"ca.crt\")\n"+ + " -out-key string\n"+ + " \tOptional: path to write the private key to (default \"ca.key\")\n", + ob.String(), + ) +} + +func Test_ca(t *testing.T) { + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + + // required args + assertHelpError(t, ca([]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb), "-name is required") + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + // failed key write + ob.Reset() + eb.Reset() + args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} + assert.EqualError(t, ca(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + // create temp key file + keyF, err := ioutil.TempFile("", "test.key") + assert.Nil(t, err) + os.Remove(keyF.Name()) + + // failed cert write + ob.Reset() + eb.Reset() + args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} + assert.EqualError(t, ca(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + // create temp cert file + crtF, err := ioutil.TempFile("", "test.crt") + assert.Nil(t, err) + os.Remove(crtF.Name()) + os.Remove(keyF.Name()) + + // test proper cert with removed empty groups and subnets + ob.Reset() + eb.Reset() + args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.Nil(t, ca(args, ob, eb)) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + // read cert and key files + rb, _ := ioutil.ReadFile(keyF.Name()) + lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb) + assert.Len(t, b, 0) + assert.Nil(t, err) + assert.Len(t, lKey, 64) + + rb, _ = ioutil.ReadFile(crtF.Name()) + lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) + assert.Len(t, b, 0) + assert.Nil(t, err) + + assert.Equal(t, "test", lCrt.Details.Name) + assert.Len(t, lCrt.Details.Ips, 0) + assert.True(t, lCrt.Details.IsCA) + assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups) + assert.Len(t, lCrt.Details.Subnets, 0) + assert.Len(t, lCrt.Details.PublicKey, 32) + assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore)) + assert.Equal(t, "", lCrt.Details.Issuer) + assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey)) + + // create valid cert/key for overwrite tests + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.Nil(t, ca(args, ob, eb)) + + // test that we won't overwrite existing certificate file + ob.Reset() + eb.Reset() + args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA key: "+keyF.Name()) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + // test that we won't overwrite existing key file + os.Remove(keyF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA cert: "+crtF.Name()) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + os.Remove(keyF.Name()) + +} diff --git a/cmd/nebula-cert/keygen.go b/cmd/nebula-cert/keygen.go new file mode 100644 index 0000000..4f15af8 --- /dev/null +++ b/cmd/nebula-cert/keygen.go @@ -0,0 +1,65 @@ +package main + +import ( + "flag" + "fmt" + "io" + "io/ioutil" + "os" + + "github.com/slackhq/nebula/cert" +) + +type keygenFlags struct { + set *flag.FlagSet + outKeyPath *string + outPubPath *string +} + +func newKeygenFlags() *keygenFlags { + cf := keygenFlags{set: flag.NewFlagSet("keygen", flag.ContinueOnError)} + cf.set.Usage = func() {} + cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to") + cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to") + return &cf +} + +func keygen(args []string, out io.Writer, errOut io.Writer) error { + cf := newKeygenFlags() + err := cf.set.Parse(args) + if err != nil { + return err + } + + if err := mustFlagString("out-key", cf.outKeyPath); err != nil { + return err + } + if err := mustFlagString("out-pub", cf.outPubPath); err != nil { + return err + } + + pub, rawPriv := x25519Keypair() + + err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalX25519PrivateKey(rawPriv), 0600) + if err != nil { + return fmt.Errorf("error while writing out-key: %s", err) + } + + err = ioutil.WriteFile(*cf.outPubPath, cert.MarshalX25519PublicKey(pub), 0600) + if err != nil { + return fmt.Errorf("error while writing out-pub: %s", err) + } + + return nil +} + +func keygenSummary() string { + return "keygen : create a public/private key pair. the public key can be passed to `nebula-cert sign`" +} + +func keygenHelp(out io.Writer) { + cf := newKeygenFlags() + out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n")) + cf.set.SetOutput(out) + cf.set.PrintDefaults() +} diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go new file mode 100644 index 0000000..9f02a37 --- /dev/null +++ b/cmd/nebula-cert/keygen_test.go @@ -0,0 +1,92 @@ +package main + +import ( + "bytes" + "io/ioutil" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/slackhq/nebula/cert" +) + +//TODO: test file permissions + +func Test_keygenSummary(t *testing.T) { + assert.Equal(t, "keygen : create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary()) +} + +func Test_keygenHelp(t *testing.T) { + ob := &bytes.Buffer{} + keygenHelp(ob) + assert.Equal( + t, + "Usage of "+os.Args[0]+" keygen : create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+ + " -out-key string\n"+ + " \tRequired: path to write the private key to\n"+ + " -out-pub string\n"+ + " \tRequired: path to write the public key to\n", + ob.String(), + ) +} + +func Test_keygen(t *testing.T) { + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + + // required args + assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required") + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required") + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + // failed key write + ob.Reset() + eb.Reset() + args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"} + assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + // create temp key file + keyF, err := ioutil.TempFile("", "test.key") + assert.Nil(t, err) + defer os.Remove(keyF.Name()) + + // failed pub write + ob.Reset() + eb.Reset() + args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()} + assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + // create temp pub file + pubF, err := ioutil.TempFile("", "test.pub") + assert.Nil(t, err) + defer os.Remove(pubF.Name()) + + // test proper keygen + ob.Reset() + eb.Reset() + args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} + assert.Nil(t, keygen(args, ob, eb)) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + // read cert and key files + rb, _ := ioutil.ReadFile(keyF.Name()) + lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) + assert.Len(t, b, 0) + assert.Nil(t, err) + assert.Len(t, lKey, 32) + + rb, _ = ioutil.ReadFile(pubF.Name()) + lPub, b, err := cert.UnmarshalX25519PublicKey(rb) + assert.Len(t, b, 0) + assert.Nil(t, err) + assert.Len(t, lPub, 32) +} diff --git a/cmd/nebula-cert/main.go b/cmd/nebula-cert/main.go new file mode 100644 index 0000000..61d1cf4 --- /dev/null +++ b/cmd/nebula-cert/main.go @@ -0,0 +1,137 @@ +package main + +import ( + "flag" + "fmt" + "io" + "os" +) + +var Build string + +type helpError struct { + s string +} + +func (he *helpError) Error() string { + return he.s +} + +func newHelpErrorf(s string, v ...interface{}) error { + return &helpError{s: fmt.Sprintf(s, v...)} +} + +func main() { + flag.Usage = func() { + help("", os.Stderr) + os.Exit(1) + } + + printVersion := flag.Bool("version", false, "Print version") + flagHelp := flag.Bool("help", false, "Print command line usage") + flagH := flag.Bool("h", false, "Print command line usage") + printUsage := false + + flag.Parse() + + if *flagH || *flagHelp { + printUsage = true + } + + args := flag.Args() + + if *printVersion { + fmt.Printf("Version: %v", Build) + os.Exit(0) + } + + if len(args) < 1 { + if printUsage { + help("", os.Stderr) + os.Exit(0) + } + + help("No mode was provided", os.Stderr) + os.Exit(1) + } else if printUsage { + handleError(args[0], &helpError{}, os.Stderr) + os.Exit(0) + } + + var err error + + switch args[0] { + case "ca": + err = ca(args[1:], os.Stdout, os.Stderr) + case "keygen": + err = keygen(args[1:], os.Stdout, os.Stderr) + case "sign": + err = signCert(args[1:], os.Stdout, os.Stderr) + case "print": + err = printCert(args[1:], os.Stdout, os.Stderr) + case "verify": + err = verify(args[1:], os.Stdout, os.Stderr) + default: + err = fmt.Errorf("unknown mode: %s", args[0]) + } + + if err != nil { + os.Exit(handleError(args[0], err, os.Stderr)) + } +} + +func handleError(mode string, e error, out io.Writer) int { + code := 1 + + // Handle -help, -h flags properly + if e == flag.ErrHelp { + code = 0 + e = &helpError{} + } else if e != nil && e.Error() != "" { + fmt.Fprintln(out, "Error:", e) + } + + switch e.(type) { + case *helpError: + switch mode { + case "ca": + caHelp(out) + case "keygen": + keygenHelp(out) + case "sign": + signHelp(out) + case "print": + printHelp(out) + case "verify": + verifyHelp(out) + } + } + + return code +} + +func help(err string, out io.Writer) { + if err != "" { + fmt.Fprintln(out, "Error:", err) + fmt.Fprintln(out, "") + } + + fmt.Fprintf(out, "Usage of %s :\n", os.Args[0]) + fmt.Fprintln(out, " Global flags:") + fmt.Fprintln(out, " -version: Prints the version") + fmt.Fprintln(out, " -h, -help: Prints this help message") + fmt.Fprintln(out, "") + fmt.Fprintln(out, " Modes:") + fmt.Fprintln(out, " "+caSummary()) + fmt.Fprintln(out, " "+keygenSummary()) + fmt.Fprintln(out, " "+signSummary()) + fmt.Fprintln(out, " "+printSummary()) + fmt.Fprintln(out, " "+verifySummary()) +} + +func mustFlagString(name string, val *string) error { + if *val == "" { + return newHelpErrorf("-%s is required", name) + } + return nil +} diff --git a/cmd/nebula-cert/main_test.go b/cmd/nebula-cert/main_test.go new file mode 100644 index 0000000..0c3360e --- /dev/null +++ b/cmd/nebula-cert/main_test.go @@ -0,0 +1,81 @@ +package main + +import ( + "bytes" + "errors" + "github.com/stretchr/testify/assert" + "io" + "os" + "testing" +) + +//TODO: all flag parsing continueOnError will print to stderr on its own currently + +func Test_help(t *testing.T) { + expected := "Usage of " + os.Args[0] + " :\n" + + " Global flags:\n" + + " -version: Prints the version\n" + + " -h, -help: Prints this help message\n\n" + + " Modes:\n" + + " " + caSummary() + "\n" + + " " + keygenSummary() + "\n" + + " " + signSummary() + "\n" + + " " + printSummary() + "\n" + + " " + verifySummary() + "\n" + + ob := &bytes.Buffer{} + + // No error test + help("", ob) + assert.Equal( + t, + expected, + ob.String(), + ) + + // Error test + ob.Reset() + help("test error", ob) + assert.Equal( + t, + "Error: test error\n\n"+expected, + ob.String(), + ) +} + +func Test_handleError(t *testing.T) { + ob := &bytes.Buffer{} + + // normal error + handleError("", errors.New("test error"), ob) + assert.Equal(t, "Error: test error\n", ob.String()) + + // unknown mode help error + ob.Reset() + handleError("", newHelpErrorf("test %s", "error"), ob) + assert.Equal(t, "Error: test error\n", ob.String()) + + // test all modes with help error + modes := map[string]func(io.Writer){"ca": caHelp, "print": printHelp, "sign": signHelp, "verify": verifyHelp} + eb := &bytes.Buffer{} + for mode, fn := range modes { + ob.Reset() + eb.Reset() + fn(eb) + + handleError(mode, newHelpErrorf("test %s", "error"), ob) + assert.Equal(t, "Error: test error\n"+eb.String(), ob.String()) + } + +} + +func assertHelpError(t *testing.T, err error, msg string) { + switch err.(type) { + case *helpError: + // good + default: + t.Fatal("err was not a helpError") + } + + assert.EqualError(t, err, msg) +} diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go new file mode 100644 index 0000000..8ab747a --- /dev/null +++ b/cmd/nebula-cert/print.go @@ -0,0 +1,80 @@ +package main + +import ( + "encoding/json" + "flag" + "fmt" + "io" + "io/ioutil" + "os" + "github.com/slackhq/nebula/cert" + "strings" +) + +type printFlags struct { + set *flag.FlagSet + json *bool + path *string +} + +func newPrintFlags() *printFlags { + pf := printFlags{set: flag.NewFlagSet("print", flag.ContinueOnError)} + pf.set.Usage = func() {} + pf.json = pf.set.Bool("json", false, "Optional: outputs certificates in json format") + pf.path = pf.set.String("path", "", "Required: path to the certificate") + + return &pf +} + +func printCert(args []string, out io.Writer, errOut io.Writer) error { + pf := newPrintFlags() + err := pf.set.Parse(args) + if err != nil { + return err + } + + if err := mustFlagString("path", pf.path); err != nil { + return err + } + + rawCert, err := ioutil.ReadFile(*pf.path) + if err != nil { + return fmt.Errorf("unable to read cert; %s", err) + } + + var c *cert.NebulaCertificate + + for { + c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert) + if err != nil { + return fmt.Errorf("error while unmarshaling cert: %s", err) + } + + if *pf.json { + b, _ := json.Marshal(c) + out.Write(b) + out.Write([]byte("\n")) + + } else { + out.Write([]byte(c.String())) + out.Write([]byte("\n")) + } + + if rawCert == nil || len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" { + break + } + } + + return nil +} + +func printSummary() string { + return "print : prints details about a certificate" +} + +func printHelp(out io.Writer) { + pf := newPrintFlags() + out.Write([]byte("Usage of " + os.Args[0] + " " + printSummary() + "\n")) + pf.set.SetOutput(out) + pf.set.PrintDefaults() +} diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go new file mode 100644 index 0000000..ddc5fff --- /dev/null +++ b/cmd/nebula-cert/print_test.go @@ -0,0 +1,119 @@ +package main + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "io/ioutil" + "os" + "github.com/slackhq/nebula/cert" + "testing" + "time" +) + +func Test_printSummary(t *testing.T) { + assert.Equal(t, "print : prints details about a certificate", printSummary()) +} + +func Test_printHelp(t *testing.T) { + ob := &bytes.Buffer{} + printHelp(ob) + assert.Equal( + t, + "Usage of "+os.Args[0]+" print : prints details about a certificate\n"+ + " -json\n"+ + " \tOptional: outputs certificates in json format\n"+ + " -path string\n"+ + " \tRequired: path to the certificate\n", + ob.String(), + ) +} + +func Test_printCert(t *testing.T) { + // Orient our local time and avoid headaches + time.Local = time.UTC + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + + // no path + err := printCert([]string{}, ob, eb) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + assertHelpError(t, err, "-path is required") + + // no cert at path + ob.Reset() + eb.Reset() + err = printCert([]string{"-path", "does_not_exist"}, ob, eb) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) + + // invalid cert at path + ob.Reset() + eb.Reset() + tf, err := ioutil.TempFile("", "print-cert") + assert.Nil(t, err) + defer os.Remove(tf.Name()) + + tf.WriteString("-----BEGIN NOPE-----") + err = printCert([]string{"-path", tf.Name()}, ob, eb) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") + + // test multiple certs + ob.Reset() + eb.Reset() + tf.Truncate(0) + tf.Seek(0, 0) + c := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "test", + Groups: []string{"hi"}, + PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, + }, + Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, + } + + p, _ := c.MarshalToPEM() + tf.Write(p) + tf.Write(p) + tf.Write(p) + + err = printCert([]string{"-path", tf.Name()}, ob, eb) + assert.Nil(t, err) + assert.Equal( + t, + "NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n", + ob.String(), + ) + assert.Equal(t, "", eb.String()) + + // test json + ob.Reset() + eb.Reset() + tf.Truncate(0) + tf.Seek(0, 0) + c = cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "test", + Groups: []string{"hi"}, + PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, + }, + Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}, + } + + p, _ = c.MarshalToPEM() + tf.Write(p) + tf.Write(p) + tf.Write(p) + + err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb) + assert.Nil(t, err) + assert.Equal( + t, + "{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n", + ob.String(), + ) + assert.Equal(t, "", eb.String()) +} diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go new file mode 100644 index 0000000..89493dc --- /dev/null +++ b/cmd/nebula-cert/sign.go @@ -0,0 +1,227 @@ +package main + +import ( + "crypto/rand" + "flag" + "fmt" + "io" + "io/ioutil" + "net" + "os" + "strings" + "time" + + "golang.org/x/crypto/curve25519" + "github.com/slackhq/nebula/cert" +) + +type signFlags struct { + set *flag.FlagSet + caKeyPath *string + caCertPath *string + name *string + ip *string + duration *time.Duration + inPubPath *string + outKeyPath *string + outCertPath *string + groups *string + subnets *string +} + +func newSignFlags() *signFlags { + sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} + sf.set.Usage = func() {} + sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") + sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") + sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") + sf.ip = sf.set.String("ip", "", "Required: ip and network in CIDR notation to assign the cert") + sf.duration = sf.set.Duration("duration", 0, "Required: how long the cert should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") + sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key") + sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to") + sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to") + sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups") + sf.subnets = sf.set.String("subnets", "", "Optional: comma seperated list of subnet this cert can serve for") + return &sf + +} + +func signCert(args []string, out io.Writer, errOut io.Writer) error { + sf := newSignFlags() + err := sf.set.Parse(args) + if err != nil { + return err + } + + if err := mustFlagString("ca-key", sf.caKeyPath); err != nil { + return err + } + if err := mustFlagString("ca-crt", sf.caCertPath); err != nil { + return err + } + if err := mustFlagString("name", sf.name); err != nil { + return err + } + if err := mustFlagString("ip", sf.ip); err != nil { + return err + } + if *sf.inPubPath != "" && *sf.outKeyPath != "" { + return newHelpErrorf("cannot set both -in-pub and -out-key") + } + + rawCAKey, err := ioutil.ReadFile(*sf.caKeyPath) + if err != nil { + return fmt.Errorf("error while reading ca-key: %s", err) + } + + caKey, _, err := cert.UnmarshalEd25519PrivateKey(rawCAKey) + if err != nil { + return fmt.Errorf("error while parsing ca-key: %s", err) + } + + rawCACert, err := ioutil.ReadFile(*sf.caCertPath) + if err != nil { + return fmt.Errorf("error while reading ca-crt: %s", err) + } + + caCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCACert) + if err != nil { + return fmt.Errorf("error while parsing ca-crt: %s", err) + } + + issuer, err := caCert.Sha256Sum() + if err != nil { + return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err) + } + + if caCert.Expired(time.Now()) { + return fmt.Errorf("ca certificate is expired") + } + + // if no duration is given, expire one second before the root expires + if *sf.duration <= 0 { + *sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1 + } + + if caCert.Details.NotAfter.Before(time.Now().Add(*sf.duration)) { + return fmt.Errorf("refusing to generate certificate with duration beyond root expiration: %s", caCert.Details.NotAfter) + } + + ip, ipNet, err := net.ParseCIDR(*sf.ip) + if err != nil { + return newHelpErrorf("invalid ip definition: %s", err) + } + ipNet.IP = ip + + groups := []string{} + if *sf.groups != "" { + for _, rg := range strings.Split(*sf.groups, ",") { + g := strings.TrimSpace(rg) + if g != "" { + groups = append(groups, g) + } + } + } + + subnets := []*net.IPNet{} + if *sf.subnets != "" { + for _, rs := range strings.Split(*sf.subnets, ",") { + rs := strings.Trim(rs, " ") + if rs != "" { + _, s, err := net.ParseCIDR(rs) + if err != nil { + return newHelpErrorf("invalid subnet definition: %s", err) + } + subnets = append(subnets, s) + } + } + } + + var pub, rawPriv []byte + if *sf.inPubPath != "" { + rawPub, err := ioutil.ReadFile(*sf.inPubPath) + if err != nil { + return fmt.Errorf("error while reading in-pub: %s", err) + } + pub, _, err = cert.UnmarshalX25519PublicKey(rawPub) + if err != nil { + return fmt.Errorf("error while parsing in-pub: %s", err) + } + } else { + pub, rawPriv = x25519Keypair() + } + + nc := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: *sf.name, + Ips: []*net.IPNet{ipNet}, + Groups: groups, + Subnets: subnets, + NotBefore: time.Now(), + NotAfter: time.Now().Add(*sf.duration), + PublicKey: pub, + IsCA: false, + Issuer: issuer, + }, + } + + if *sf.outKeyPath == "" { + *sf.outKeyPath = *sf.name + ".key" + } + + if *sf.outCertPath == "" { + *sf.outCertPath = *sf.name + ".crt" + } + + if _, err := os.Stat(*sf.outKeyPath); err == nil { + return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) + } + + if _, err := os.Stat(*sf.outCertPath); err == nil { + return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) + } + + err = nc.Sign(caKey) + if err != nil { + return fmt.Errorf("error while signing: %s", err) + } + + if *sf.inPubPath == "" { + err = ioutil.WriteFile(*sf.outKeyPath, cert.MarshalX25519PrivateKey(rawPriv), 0600) + if err != nil { + return fmt.Errorf("error while writing out-key: %s", err) + } + } + + b, err := nc.MarshalToPEM() + if err != nil { + return fmt.Errorf("error while marshalling certificate: %s", err) + } + + err = ioutil.WriteFile(*sf.outCertPath, b, 0600) + if err != nil { + return fmt.Errorf("error while writing out-crt: %s", err) + } + + return nil +} + +func x25519Keypair() ([]byte, []byte) { + var pubkey, privkey [32]byte + if _, err := io.ReadFull(rand.Reader, privkey[:]); err != nil { + panic(err) + } + curve25519.ScalarBaseMult(&pubkey, &privkey) + return pubkey[:], privkey[:] +} + +func signSummary() string { + return "sign : create and sign a certificate" +} + +func signHelp(out io.Writer) { + sf := newSignFlags() + out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n")) + sf.set.SetOutput(out) + sf.set.PrintDefaults() +} diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go new file mode 100644 index 0000000..a278545 --- /dev/null +++ b/cmd/nebula-cert/sign_test.go @@ -0,0 +1,281 @@ +package main + +import ( + "bytes" + "crypto/rand" + "io/ioutil" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" + "github.com/slackhq/nebula/cert" +) + +//TODO: test file permissions + +func Test_signSummary(t *testing.T) { + assert.Equal(t, "sign : create and sign a certificate", signSummary()) +} + +func Test_signHelp(t *testing.T) { + ob := &bytes.Buffer{} + signHelp(ob) + assert.Equal( + t, + "Usage of "+os.Args[0]+" sign : create and sign a certificate\n"+ + " -ca-crt string\n"+ + " \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+ + " -ca-key string\n"+ + " \tOptional: path to the signing CA key (default \"ca.key\")\n"+ + " -duration duration\n"+ + " \tRequired: how long the cert should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"\n"+ + " -groups string\n"+ + " \tOptional: comma separated list of groups\n"+ + " -in-pub string\n"+ + " \tOptional (if out-key not set): path to read a previously generated public key\n"+ + " -ip string\n"+ + " \tRequired: ip and network in CIDR notation to assign the cert\n"+ + " -name string\n"+ + " \tRequired: name of the cert, usually a hostname\n"+ + " -out-crt string\n"+ + " \tOptional: path to write the certificate to\n"+ + " -out-key string\n"+ + " \tOptional (if in-pub not set): path to write the private key to\n"+ + " -subnets string\n"+ + " \tOptional: comma seperated list of subnet this cert can serve for\n", + ob.String(), + ) +} + +func Test_signCert(t *testing.T) { + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + + // required args + + assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-name is required") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-ip is required") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // cannot set -in-pub and -out-key + assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb), "cannot set both -in-pub and -out-key") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // failed to read key + ob.Reset() + eb.Reset() + args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-key: open ./nope: "+NoSuchFileError) + + // failed to unmarshal key + ob.Reset() + eb.Reset() + caKeyF, err := ioutil.TempFile("", "sign-cert.key") + assert.Nil(t, err) + defer os.Remove(caKeyF.Name()) + + args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-key: input did not contain a valid PEM encoded block") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // Write a proper ca key for later + ob.Reset() + eb.Reset() + caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) + caKeyF.Write(cert.MarshalEd25519PrivateKey(caPriv)) + + // failed to read cert + args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-crt: open ./nope: "+NoSuchFileError) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // failed to unmarshal cert + ob.Reset() + eb.Reset() + caCrtF, err := ioutil.TempFile("", "sign-cert.crt") + assert.Nil(t, err) + defer os.Remove(caCrtF.Name()) + + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-crt: input did not contain a valid PEM encoded block") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // write a proper ca cert for later + ca := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "ca", + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Minute * 200), + PublicKey: caPub, + IsCA: true, + }, + } + b, _ := ca.MarshalToPEM() + caCrtF.Write(b) + + // failed to read pub + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} + assert.EqualError(t, signCert(args, ob, eb), "error while reading in-pub: open ./nope: "+NoSuchFileError) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // failed to unmarshal pub + ob.Reset() + eb.Reset() + inPubF, err := ioutil.TempFile("", "in.pub") + assert.Nil(t, err) + defer os.Remove(inPubF.Name()) + + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} + assert.EqualError(t, signCert(args, ob, eb), "error while parsing in-pub: input did not contain a valid PEM encoded block") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // write a proper pub for later + ob.Reset() + eb.Reset() + inPub, _ := x25519Keypair() + inPubF.Write(cert.MarshalX25519PublicKey(inPub)) + + // bad ip cidr + ob.Reset() + eb.Reset() + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb), "invalid ip definition: invalid CIDR address: a1.1.1.1/24") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // bad subnet cidr + ob.Reset() + eb.Reset() + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} + assertHelpError(t, signCert(args, ob, eb), "invalid subnet definition: invalid CIDR address: a") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // failed key write + ob.Reset() + eb.Reset() + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} + assert.EqualError(t, signCert(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // create temp key file + keyF, err := ioutil.TempFile("", "test.key") + assert.Nil(t, err) + os.Remove(keyF.Name()) + + // failed cert write + ob.Reset() + eb.Reset() + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} + assert.EqualError(t, signCert(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + os.Remove(keyF.Name()) + + // create temp cert file + crtF, err := ioutil.TempFile("", "test.crt") + assert.Nil(t, err) + os.Remove(crtF.Name()) + + // test proper cert with removed empty groups and subnets + ob.Reset() + eb.Reset() + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Nil(t, signCert(args, ob, eb)) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // read cert and key files + rb, _ := ioutil.ReadFile(keyF.Name()) + lKey, b, err := cert.UnmarshalX25519PrivateKey(rb) + assert.Len(t, b, 0) + assert.Nil(t, err) + assert.Len(t, lKey, 32) + + rb, _ = ioutil.ReadFile(crtF.Name()) + lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb) + assert.Len(t, b, 0) + assert.Nil(t, err) + + assert.Equal(t, "test", lCrt.Details.Name) + assert.Equal(t, "1.1.1.1/24", lCrt.Details.Ips[0].String()) + assert.Len(t, lCrt.Details.Ips, 1) + assert.False(t, lCrt.Details.IsCA) + assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups) + assert.Len(t, lCrt.Details.Subnets, 3) + assert.Len(t, lCrt.Details.PublicKey, 32) + assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore)) + + sns := []string{} + for _, sn := range lCrt.Details.Subnets { + sns = append(sns, sn.String()) + } + assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns) + + issuer, _ := ca.Sha256Sum() + assert.Equal(t, issuer, lCrt.Details.Issuer) + + assert.True(t, lCrt.CheckSignature(caPub)) + + // test proper cert with in-pub + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} + assert.Nil(t, signCert(args, ob, eb)) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // read cert file and check pub key matches in-pub + rb, _ = ioutil.ReadFile(crtF.Name()) + lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb) + assert.Len(t, b, 0) + assert.Nil(t, err) + assert.Equal(t, lCrt.Details.PublicKey, inPub) + + // test refuse to sign cert with duration beyond root + ob.Reset() + eb.Reset() + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.EqualError(t, signCert(args, ob, eb), "refusing to generate certificate with duration beyond root expiration: "+ca.Details.NotAfter.Format("2006-01-02 15:04:05 +0000 UTC")) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // create valid cert/key for overwrite tests + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.Nil(t, signCert(args, ob, eb)) + + // test that we won't overwrite existing key file + ob.Reset() + eb.Reset() + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing key: "+keyF.Name()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // test that we won't overwrite existing certificate file + os.Remove(keyF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing cert: "+crtF.Name()) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + +} diff --git a/cmd/nebula-cert/test_unix.go b/cmd/nebula-cert/test_unix.go new file mode 100644 index 0000000..7276dfa --- /dev/null +++ b/cmd/nebula-cert/test_unix.go @@ -0,0 +1,4 @@ +package main + +const NoSuchFileError = "no such file or directory" +const NoSuchDirError = "no such file or directory" diff --git a/cmd/nebula-cert/test_windows.go b/cmd/nebula-cert/test_windows.go new file mode 100644 index 0000000..f364296 --- /dev/null +++ b/cmd/nebula-cert/test_windows.go @@ -0,0 +1,4 @@ +package main + +const NoSuchFileError = "The system cannot find the file specified." +const NoSuchDirError = "The system cannot find the path specified." diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go new file mode 100644 index 0000000..5574da4 --- /dev/null +++ b/cmd/nebula-cert/verify.go @@ -0,0 +1,86 @@ +package main + +import ( + "flag" + "fmt" + "io" + "io/ioutil" + "os" + "github.com/slackhq/nebula/cert" + "strings" + "time" +) + +type verifyFlags struct { + set *flag.FlagSet + caPath *string + certPath *string +} + +func newVerifyFlags() *verifyFlags { + vf := verifyFlags{set: flag.NewFlagSet("verify", flag.ContinueOnError)} + vf.set.Usage = func() {} + vf.caPath = vf.set.String("ca", "", "Required: path to a file containing one or more ca certificates") + vf.certPath = vf.set.String("crt", "", "Required: path to a file containing a single certificate") + return &vf +} + +func verify(args []string, out io.Writer, errOut io.Writer) error { + vf := newVerifyFlags() + err := vf.set.Parse(args) + if err != nil { + return err + } + + if err := mustFlagString("ca", vf.caPath); err != nil { + return err + } + if err := mustFlagString("crt", vf.certPath); err != nil { + return err + } + + rawCACert, err := ioutil.ReadFile(*vf.caPath) + if err != nil { + return fmt.Errorf("error while reading ca: %s", err) + } + + caPool := cert.NewCAPool() + for { + rawCACert, err = caPool.AddCACertificate(rawCACert) + if err != nil { + return fmt.Errorf("error while adding ca cert to pool: %s", err) + } + + if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { + break + } + } + + rawCert, err := ioutil.ReadFile(*vf.certPath) + if err != nil { + return fmt.Errorf("unable to read crt; %s", err) + } + + c, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) + if err != nil { + return fmt.Errorf("error while parsing crt: %s", err) + } + + good, err := c.Verify(time.Now(), caPool) + if !good { + return err + } + + return nil +} + +func verifySummary() string { + return "verify : verifies a certificate isn't expired and was signed by a trusted authority." +} + +func verifyHelp(out io.Writer) { + vf := newVerifyFlags() + out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) + vf.set.SetOutput(out) + vf.set.PrintDefaults() +} diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go new file mode 100644 index 0000000..5722662 --- /dev/null +++ b/cmd/nebula-cert/verify_test.go @@ -0,0 +1,141 @@ +package main + +import ( + "bytes" + "crypto/rand" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" + "io/ioutil" + "os" + "github.com/slackhq/nebula/cert" + "testing" + "time" +) + +func Test_verifySummary(t *testing.T) { + assert.Equal(t, "verify : verifies a certificate isn't expired and was signed by a trusted authority.", verifySummary()) +} + +func Test_verifyHelp(t *testing.T) { + ob := &bytes.Buffer{} + verifyHelp(ob) + assert.Equal( + t, + "Usage of "+os.Args[0]+" verify : verifies a certificate isn't expired and was signed by a trusted authority.\n"+ + " -ca string\n"+ + " \tRequired: path to a file containing one or more ca certificates\n"+ + " -crt string\n"+ + " \tRequired: path to a file containing a single certificate\n", + ob.String(), + ) +} + +func Test_verify(t *testing.T) { + time.Local = time.UTC + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + + // required args + assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required") + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required") + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + + // no ca at path + ob.Reset() + eb.Reset() + err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) + + // invalid ca at path + ob.Reset() + eb.Reset() + caFile, err := ioutil.TempFile("", "verify-ca") + assert.Nil(t, err) + defer os.Remove(caFile.Name()) + + caFile.WriteString("-----BEGIN NOPE-----") + err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") + + // make a ca for later + caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) + ca := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "test-ca", + NotBefore: time.Now().Add(time.Hour * -1), + NotAfter: time.Now().Add(time.Hour), + PublicKey: caPub, + IsCA: true, + }, + } + ca.Sign(caPriv) + b, _ := ca.MarshalToPEM() + caFile.Truncate(0) + caFile.Seek(0, 0) + caFile.Write(b) + + // no crt at path + err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError) + + // invalid crt at path + ob.Reset() + eb.Reset() + certFile, err := ioutil.TempFile("", "verify-cert") + assert.Nil(t, err) + defer os.Remove(certFile.Name()) + + certFile.WriteString("-----BEGIN NOPE-----") + err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") + + // unverifiable cert at path + _, badPriv, _ := ed25519.GenerateKey(rand.Reader) + certPub, _ := x25519Keypair() + signer, _ := ca.Sha256Sum() + crt := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "test-cert", + NotBefore: time.Now().Add(time.Hour * -1), + NotAfter: time.Now().Add(time.Hour), + PublicKey: certPub, + IsCA: false, + Issuer: signer, + }, + } + + crt.Sign(badPriv) + b, _ = crt.MarshalToPEM() + certFile.Truncate(0) + certFile.Seek(0, 0) + certFile.Write(b) + + err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + assert.EqualError(t, err, "certificate signature did not match") + + // verified cert at path + crt.Sign(caPriv) + b, _ = crt.MarshalToPEM() + certFile.Truncate(0) + certFile.Seek(0, 0) + certFile.Write(b) + + err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) + assert.Equal(t, "", ob.String()) + assert.Equal(t, "", eb.String()) + assert.Nil(t, err) +} diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go new file mode 100644 index 0000000..fa42d94 --- /dev/null +++ b/cmd/nebula/main.go @@ -0,0 +1,43 @@ +package main + +import ( + "flag" + "fmt" + "os" + + "github.com/slackhq/nebula" +) + +// A version string that can be set with +// +// -ldflags "-X main.Build=SOMEVERSION" +// +// at compile-time. +var Build string + +func main() { + configPath := flag.String("config", "", "Path to either a file or directory to load configuration from") + configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config") + printVersion := flag.Bool("version", false, "Print version") + printUsage := flag.Bool("help", false, "Print command line usage") + + flag.Parse() + + if *printVersion { + fmt.Printf("Build: %s\n", Build) + os.Exit(0) + } + + if *printUsage { + flag.Usage() + os.Exit(0) + } + + if *configPath == "" { + fmt.Println("-config flag must be set") + flag.Usage() + os.Exit(1) + } + + nebula.Main(*configPath, *configTest, Build) +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..e77ccf0 --- /dev/null +++ b/config.go @@ -0,0 +1,338 @@ +package nebula + +import ( + "fmt" + "github.com/imdario/mergo" + "github.com/sirupsen/logrus" + "gopkg.in/yaml.v2" + "io/ioutil" + "os" + "os/signal" + "path/filepath" + "sort" + "strconv" + "strings" + "syscall" + "time" +) + +type Config struct { + path string + files []string + Settings map[interface{}]interface{} + oldSettings map[interface{}]interface{} + callbacks []func(*Config) +} + +func NewConfig() *Config { + return &Config{ + Settings: make(map[interface{}]interface{}), + } +} + +// Load will find all yaml files within path and load them in lexical order +func (c *Config) Load(path string) error { + c.path = path + c.files = make([]string, 0) + + err := c.resolve(path) + if err != nil { + return err + } + + sort.Strings(c.files) + + err = c.parse() + if err != nil { + return err + } + + return nil +} + +// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered +// here should decide if they need to make a change to the current process before making the change. HasChanged can be +// used to help decide if a change is necessary. +// These functions should return quickly or spawn their own go routine if they will take a while +func (c *Config) RegisterReloadCallback(f func(*Config)) { + c.callbacks = append(c.callbacks, f) +} + +// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of +// k in both the old and new settings will be serialized, the result of the string comparison is returned. +// If k is an empty string the entire config is tested. +// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating +// there is change when there actually wasn't any. +func (c *Config) HasChanged(k string) bool { + if c.oldSettings == nil { + return false + } + + var ( + nv interface{} + ov interface{} + ) + + if k == "" { + nv = c.Settings + ov = c.oldSettings + k = "all settings" + } else { + nv = c.get(k, c.Settings) + ov = c.get(k, c.oldSettings) + } + + newVals, err := yaml.Marshal(nv) + if err != nil { + l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") + } + + oldVals, err := yaml.Marshal(ov) + if err != nil { + l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") + } + + return string(newVals) != string(oldVals) +} + +// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the +// original path provided to Load. The old settings are shallow copied for change detection after the reload. +func (c *Config) CatchHUP() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGHUP) + + go func() { + for range ch { + l.Info("Caught HUP, reloading config") + c.ReloadConfig() + } + }() +} + +func (c *Config) ReloadConfig() { + c.oldSettings = make(map[interface{}]interface{}) + for k, v := range c.Settings { + c.oldSettings[k] = v + } + + err := c.Load(c.path) + if err != nil { + l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") + return + } + + for _, v := range c.callbacks { + v(c) + } +} + +// GetString will get the string for k or return the default d if not found or invalid +func (c *Config) GetString(k, d string) string { + r := c.Get(k) + if r == nil { + return d + } + + return fmt.Sprintf("%v", r) +} + +// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid +func (c *Config) GetStringSlice(k string, d []string) []string { + r := c.Get(k) + if r == nil { + return d + } + + rv, ok := r.([]interface{}) + if !ok { + return d + } + + v := make([]string, len(rv)) + for i := 0; i < len(v); i++ { + v[i] = fmt.Sprintf("%v", rv[i]) + } + + return v +} + +// GetMap will get the map for k or return the default d if not found or invalid +func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} { + r := c.Get(k) + if r == nil { + return d + } + + v, ok := r.(map[interface{}]interface{}) + if !ok { + return d + } + + return v +} + +// GetInt will get the int for k or return the default d if not found or invalid +func (c *Config) GetInt(k string, d int) int { + r := c.GetString(k, strconv.Itoa(d)) + v, err := strconv.Atoi(r) + if err != nil { + return d + } + + return v +} + +// GetBool will get the bool for k or return the default d if not found or invalid +func (c *Config) GetBool(k string, d bool) bool { + r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d))) + v, err := strconv.ParseBool(r) + if err != nil { + switch r { + case "y", "yes": + return true + case "n", "no": + return false + } + return d + } + + return v +} + +// GetDuration will get the duration for k or return the default d if not found or invalid +func (c *Config) GetDuration(k string, d time.Duration) time.Duration { + r := c.GetString(k, "") + v, err := time.ParseDuration(r) + if err != nil { + return d + } + return v +} + +func (c *Config) Get(k string) interface{} { + return c.get(k, c.Settings) +} + +func (c *Config) get(k string, v interface{}) interface{} { + parts := strings.Split(k, ".") + for _, p := range parts { + m, ok := v.(map[interface{}]interface{}) + if !ok { + return nil + } + + v, ok = m[p] + if !ok { + return nil + } + } + + return v +} + +func (c *Config) resolve(path string) error { + i, err := os.Stat(path) + if err != nil { + return nil + } + + if !i.IsDir() { + c.addFile(path) + return nil + } + + paths, err := readDirNames(path) + if err != nil { + return fmt.Errorf("problem while reading directory %s: %s", path, err) + } + + for _, p := range paths { + err := c.resolve(filepath.Join(path, p)) + if err != nil { + return err + } + } + + return nil +} + +func (c *Config) addFile(path string) error { + ext := filepath.Ext(path) + + if ext != ".yaml" && ext != ".yml" { + return nil + } + + ap, err := filepath.Abs(path) + if err != nil { + return err + } + + c.files = append(c.files, ap) + return nil +} + +func (c *Config) parse() error { + var m map[interface{}]interface{} + + for _, path := range c.files { + b, err := ioutil.ReadFile(path) + if err != nil { + return err + } + + var nm map[interface{}]interface{} + err = yaml.Unmarshal(b, &nm) + if err != nil { + return err + } + + // We need to use WithAppendSlice so that firewall rules in separate + // files are appended together + err = mergo.Merge(&nm, m, mergo.WithAppendSlice) + m = nm + if err != nil { + return err + } + } + + c.Settings = m + return nil +} + +func readDirNames(path string) ([]string, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + + paths, err := f.Readdirnames(-1) + f.Close() + if err != nil { + return nil, err + } + + sort.Strings(paths) + return paths, nil +} + +func configLogger(c *Config) error { + // set up our logging level + logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) + if err != nil { + return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels) + } + l.SetLevel(logLevel) + + logFormat := strings.ToLower(c.GetString("logging.format", "text")) + switch logFormat { + case "text": + l.Formatter = &logrus.TextFormatter{} + case "json": + l.Formatter = &logrus.JSONFormatter{} + default: + return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) + } + + return nil +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..0e1036c --- /dev/null +++ b/config_test.go @@ -0,0 +1,141 @@ +package nebula + +import ( + "github.com/stretchr/testify/assert" + "io/ioutil" + "os" + "path/filepath" + "testing" + "time" +) + +func TestConfig_Load(t *testing.T) { + dir, err := ioutil.TempDir("", "config-test") + // invalid yaml + c := NewConfig() + ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) + assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") + + // simple multi config merge + c = NewConfig() + os.RemoveAll(dir) + os.Mkdir(dir, 0755) + + assert.Nil(t, err) + + ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) + ioutil.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) + assert.Nil(t, c.Load(dir)) + expected := map[interface{}]interface{}{ + "outer": map[interface{}]interface{}{ + "inner": "override", + }, + "new": "hi", + } + assert.Equal(t, expected, c.Settings) + + //TODO: test symlinked file + //TODO: test symlinked directory +} + +func TestConfig_Get(t *testing.T) { + // test simple type + c := NewConfig() + c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} + assert.Equal(t, "hi", c.Get("firewall.outbound")) + + // test complex type + inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}} + c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner} + assert.EqualValues(t, inner, c.Get("firewall.outbound")) + + // test missing + assert.Nil(t, c.Get("firewall.nope")) +} + +func TestConfig_GetStringSlice(t *testing.T) { + c := NewConfig() + c.Settings["slice"] = []interface{}{"one", "two"} + assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) +} + +func TestConfig_GetBool(t *testing.T) { + c := NewConfig() + c.Settings["bool"] = true + assert.Equal(t, true, c.GetBool("bool", false)) + + c.Settings["bool"] = "true" + assert.Equal(t, true, c.GetBool("bool", false)) + + c.Settings["bool"] = false + assert.Equal(t, false, c.GetBool("bool", true)) + + c.Settings["bool"] = "false" + assert.Equal(t, false, c.GetBool("bool", true)) + + c.Settings["bool"] = "Y" + assert.Equal(t, true, c.GetBool("bool", false)) + + c.Settings["bool"] = "yEs" + assert.Equal(t, true, c.GetBool("bool", false)) + + c.Settings["bool"] = "N" + assert.Equal(t, false, c.GetBool("bool", true)) + + c.Settings["bool"] = "nO" + assert.Equal(t, false, c.GetBool("bool", true)) +} + +func TestConfig_HasChanged(t *testing.T) { + // No reload has occurred, return false + c := NewConfig() + c.Settings["test"] = "hi" + assert.False(t, c.HasChanged("")) + + // Test key change + c = NewConfig() + c.Settings["test"] = "hi" + c.oldSettings = map[interface{}]interface{}{"test": "no"} + assert.True(t, c.HasChanged("test")) + assert.True(t, c.HasChanged("")) + + // No key change + c = NewConfig() + c.Settings["test"] = "hi" + c.oldSettings = map[interface{}]interface{}{"test": "hi"} + assert.False(t, c.HasChanged("test")) + assert.False(t, c.HasChanged("")) +} + +func TestConfig_ReloadConfig(t *testing.T) { + done := make(chan bool, 1) + dir, err := ioutil.TempDir("", "config-test") + assert.Nil(t, err) + ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) + + c := NewConfig() + assert.Nil(t, c.Load(dir)) + + assert.False(t, c.HasChanged("outer.inner")) + assert.False(t, c.HasChanged("outer")) + assert.False(t, c.HasChanged("")) + + ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644) + + c.RegisterReloadCallback(func(c *Config) { + done <- true + }) + + c.ReloadConfig() + assert.True(t, c.HasChanged("outer.inner")) + assert.True(t, c.HasChanged("outer")) + assert.True(t, c.HasChanged("")) + + // Make sure we call the callbacks + select { + case <-done: + case <-time.After(1 * time.Second): + panic("timeout") + } + +} diff --git a/connection_manager.go b/connection_manager.go new file mode 100644 index 0000000..33d3265 --- /dev/null +++ b/connection_manager.go @@ -0,0 +1,253 @@ +package nebula + +import ( + "github.com/sirupsen/logrus" + "sync" + "time" +) + +// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet +// and something like every 10 packets we could lock, send 10, then unlock for a moment + +type connectionManager struct { + hostMap *HostMap + in map[uint32]struct{} + inLock *sync.RWMutex + inCount int + out map[uint32]struct{} + outLock *sync.RWMutex + outCount int + TrafficTimer *SystemTimerWheel + intf *Interface + + pendingDeletion map[uint32]int + pendingDeletionLock *sync.RWMutex + pendingDeletionTimer *SystemTimerWheel + + checkInterval int + pendingDeletionInterval int + + // I wanted to call one matLock +} + +func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager { + nc := &connectionManager{ + hostMap: intf.hostMap, + in: make(map[uint32]struct{}), + inLock: &sync.RWMutex{}, + inCount: 0, + out: make(map[uint32]struct{}), + outLock: &sync.RWMutex{}, + outCount: 0, + TrafficTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60), + intf: intf, + pendingDeletion: make(map[uint32]int), + pendingDeletionLock: &sync.RWMutex{}, + pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60), + checkInterval: checkInterval, + pendingDeletionInterval: pendingDeletionInterval, + } + nc.Start() + return nc +} + +func (n *connectionManager) In(ip uint32) { + n.inLock.RLock() + // If this already exists, return + if _, ok := n.in[ip]; ok { + n.inLock.RUnlock() + return + } + n.inLock.RUnlock() + n.inLock.Lock() + n.in[ip] = struct{}{} + n.inLock.Unlock() +} + +func (n *connectionManager) Out(ip uint32) { + n.outLock.RLock() + // If this already exists, return + if _, ok := n.out[ip]; ok { + n.outLock.RUnlock() + return + } + n.outLock.RUnlock() + n.outLock.Lock() + // double check since we dropped the lock temporarily + if _, ok := n.out[ip]; ok { + n.outLock.Unlock() + return + } + n.out[ip] = struct{}{} + n.AddTrafficWatch(ip, n.checkInterval) + n.outLock.Unlock() +} + +func (n *connectionManager) CheckIn(vpnIP uint32) bool { + n.inLock.RLock() + if _, ok := n.in[vpnIP]; ok { + n.inLock.RUnlock() + return true + } + n.inLock.RUnlock() + return false +} + +func (n *connectionManager) ClearIP(ip uint32) { + n.inLock.Lock() + n.outLock.Lock() + delete(n.in, ip) + delete(n.out, ip) + n.inLock.Unlock() + n.outLock.Unlock() +} + +func (n *connectionManager) ClearPendingDeletion(ip uint32) { + n.pendingDeletionLock.Lock() + delete(n.pendingDeletion, ip) + n.pendingDeletionLock.Unlock() +} + +func (n *connectionManager) AddPendingDeletion(ip uint32) { + n.pendingDeletionLock.Lock() + if _, ok := n.pendingDeletion[ip]; ok { + n.pendingDeletion[ip] += 1 + } else { + n.pendingDeletion[ip] = 0 + } + n.pendingDeletionTimer.Add(ip, time.Second*time.Duration(n.pendingDeletionInterval)) + n.pendingDeletionLock.Unlock() +} + +func (n *connectionManager) checkPendingDeletion(ip uint32) bool { + n.pendingDeletionLock.RLock() + if _, ok := n.pendingDeletion[ip]; ok { + + n.pendingDeletionLock.RUnlock() + return true + } + n.pendingDeletionLock.RUnlock() + return false +} + +func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) { + n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds)) +} + +func (n *connectionManager) Start() { + go n.Run() +} + +func (n *connectionManager) Run() { + clockSource := time.Tick(500 * time.Millisecond) + + for now := range clockSource { + n.HandleMonitorTick(now) + n.HandleDeletionTick(now) + } +} + +func (n *connectionManager) HandleMonitorTick(now time.Time) { + n.TrafficTimer.advance(now) + for { + ep := n.TrafficTimer.Purge() + if ep == nil { + break + } + + vpnIP := ep.(uint32) + + // Check for traffic coming back in from this host. + traf := n.CheckIn(vpnIP) + + // If we saw incoming packets from this ip, just return + if traf { + if l.Level >= logrus.DebugLevel { + l.WithField("vpnIp", IntIp(vpnIP)). + WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). + Debug("Tunnel status") + } + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + continue + } + + // If we didn't we may need to probe or destroy the conn + hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) + if err != nil { + l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + continue + } + + l.WithField("vpnIp", IntIp(vpnIP)). + WithField("tunnelCheck", m{"state": "testing", "method": "active"}). + Debug("Tunnel status") + + if hostinfo != nil && hostinfo.ConnectionState != nil { + // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues + n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + + } else { + l.Debugf("Hostinfo sadness: %s", IntIp(vpnIP)) + } + n.AddPendingDeletion(vpnIP) + } + +} + +func (n *connectionManager) HandleDeletionTick(now time.Time) { + n.pendingDeletionTimer.advance(now) + for { + ep := n.pendingDeletionTimer.Purge() + if ep == nil { + break + } + + vpnIP := ep.(uint32) + + // If we saw incoming packets from this ip, just return + traf := n.CheckIn(vpnIP) + if traf { + l.WithField("vpnIp", IntIp(vpnIP)). + WithField("tunnelCheck", m{"state": "alive", "method": "active"}). + Debug("Tunnel status") + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + continue + } + + hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) + if err != nil { + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) + continue + } + + // If it comes around on deletion wheel and hasn't resolved itself, delete + if n.checkPendingDeletion(vpnIP) { + cn := "" + if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil { + cn = hostinfo.ConnectionState.peerCert.Details.Name + } + l.WithField("vpnIp", IntIp(vpnIP)). + WithField("tunnelCheck", m{"state": "dead", "method": "active"}). + WithField("certName", cn). + Info("Tunnel status") + + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + // TODO: This is only here to let tests work. Should do proper mocking + if n.intf.lightHouse != nil { + n.intf.lightHouse.DeleteVpnIP(vpnIP) + } + n.hostMap.DeleteVpnIP(vpnIP) + n.hostMap.DeleteIndex(hostinfo.localIndexId) + } else { + n.ClearIP(vpnIP) + n.ClearPendingDeletion(vpnIP) + } + } +} diff --git a/connection_manager_test.go b/connection_manager_test.go new file mode 100644 index 0000000..789b8ed --- /dev/null +++ b/connection_manager_test.go @@ -0,0 +1,141 @@ +package nebula + +import ( + "net" + "testing" + "time" + + "github.com/flynn/noise" + "github.com/stretchr/testify/assert" + "github.com/slackhq/nebula/cert" +) + +var vpnIP uint32 = uint32(12341234) + +func Test_NewConnectionManagerTest(t *testing.T) { + //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") + _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") + _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + preferredRanges := []*net.IPNet{localrange} + + // Very incomplete mock objects + hostMap := NewHostMap("test", vpncidr, preferredRanges) + cs := &CertState{ + rawCertificate: []byte{}, + privateKey: []byte{}, + certificate: &cert.NebulaCertificate{}, + rawCertificateNoKey: []byte{}, + } + + lh := NewLightHouse(false, 0, []string{}, 1000, 0, &udpConn{}, false) + ifce := &Interface{ + hostMap: hostMap, + inside: &Tun{}, + outside: &udpConn{}, + certState: cs, + firewall: &Firewall{}, + lightHouse: lh, + handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}), + } + now := time.Now() + + // Create manager + nc := newConnectionManager(ifce, 5, 10) + nc.HandleMonitorTick(now) + // Add an ip we have established a connection w/ to hostmap + hostinfo := nc.hostMap.AddVpnIP(vpnIP) + hostinfo.ConnectionState = &ConnectionState{ + certState: cs, + H: &noise.HandshakeState{}, + messageCounter: new(uint64), + } + + // We saw traffic out to vpnIP + nc.Out(vpnIP) + assert.NotContains(t, nc.pendingDeletion, vpnIP) + assert.Contains(t, nc.hostMap.Hosts, vpnIP) + // Move ahead 5s. Nothing should happen + next_tick := now.Add(5 * time.Second) + nc.HandleMonitorTick(next_tick) + nc.HandleDeletionTick(next_tick) + // Move ahead 6s. We haven't heard back + next_tick = now.Add(6 * time.Second) + nc.HandleMonitorTick(next_tick) + nc.HandleDeletionTick(next_tick) + // This host should now be up for deletion + assert.Contains(t, nc.pendingDeletion, vpnIP) + assert.Contains(t, nc.hostMap.Hosts, vpnIP) + // Move ahead some more + next_tick = now.Add(45 * time.Second) + nc.HandleMonitorTick(next_tick) + nc.HandleDeletionTick(next_tick) + // The host should be evicted + assert.NotContains(t, nc.pendingDeletion, vpnIP) + assert.NotContains(t, nc.hostMap.Hosts, vpnIP) + +} + +func Test_NewConnectionManagerTest2(t *testing.T) { + //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") + _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") + _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + preferredRanges := []*net.IPNet{localrange} + + // Very incomplete mock objects + hostMap := NewHostMap("test", vpncidr, preferredRanges) + cs := &CertState{ + rawCertificate: []byte{}, + privateKey: []byte{}, + certificate: &cert.NebulaCertificate{}, + rawCertificateNoKey: []byte{}, + } + + lh := NewLightHouse(false, 0, []string{}, 1000, 0, &udpConn{}, false) + ifce := &Interface{ + hostMap: hostMap, + inside: &Tun{}, + outside: &udpConn{}, + certState: cs, + firewall: &Firewall{}, + lightHouse: lh, + handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}), + } + now := time.Now() + + // Create manager + nc := newConnectionManager(ifce, 5, 10) + nc.HandleMonitorTick(now) + // Add an ip we have established a connection w/ to hostmap + hostinfo := nc.hostMap.AddVpnIP(vpnIP) + hostinfo.ConnectionState = &ConnectionState{ + certState: cs, + H: &noise.HandshakeState{}, + messageCounter: new(uint64), + } + + // We saw traffic out to vpnIP + nc.Out(vpnIP) + assert.NotContains(t, nc.pendingDeletion, vpnIP) + assert.Contains(t, nc.hostMap.Hosts, vpnIP) + // Move ahead 5s. Nothing should happen + next_tick := now.Add(5 * time.Second) + nc.HandleMonitorTick(next_tick) + nc.HandleDeletionTick(next_tick) + // Move ahead 6s. We haven't heard back + next_tick = now.Add(6 * time.Second) + nc.HandleMonitorTick(next_tick) + nc.HandleDeletionTick(next_tick) + // This host should now be up for deletion + assert.Contains(t, nc.pendingDeletion, vpnIP) + assert.Contains(t, nc.hostMap.Hosts, vpnIP) + // We heard back this time + nc.In(vpnIP) + // Move ahead some more + next_tick = now.Add(45 * time.Second) + nc.HandleMonitorTick(next_tick) + nc.HandleDeletionTick(next_tick) + // The host should be evicted + assert.NotContains(t, nc.pendingDeletion, vpnIP) + assert.Contains(t, nc.hostMap.Hosts, vpnIP) + +} diff --git a/connection_state.go b/connection_state.go new file mode 100644 index 0000000..2583745 --- /dev/null +++ b/connection_state.go @@ -0,0 +1,75 @@ +package nebula + +import ( + "crypto/rand" + "encoding/json" + "sync" + + "github.com/flynn/noise" + "github.com/slackhq/nebula/cert" +) + +const ReplayWindow = 1024 + +type ConnectionState struct { + eKey *NebulaCipherState + dKey *NebulaCipherState + H *noise.HandshakeState + certState *CertState + peerCert *cert.NebulaCertificate + initiator bool + messageCounter *uint64 + window *Bits + queueLock sync.Mutex + writeLock sync.Mutex + ready bool +} + +func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { + cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256) + if f.cipher == "chachapoly" { + cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) + } + + curCertState := f.certState + static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey} + + b := NewBits(ReplayWindow) + // Clear out bit 0, we never transmit it and we don't want it showing as packet loss + b.Update(0) + + hs, err := noise.NewHandshakeState(noise.Config{ + CipherSuite: cs, + Random: rand.Reader, + Pattern: pattern, + Initiator: initiator, + StaticKeypair: static, + PresharedKey: psk, + PresharedKeyPlacement: pskStage, + }) + if err != nil { + return nil + } + + // The queue and ready params prevent a counter race that would happen when + // sending stored packets and simultaneously accepting new traffic. + ci := &ConnectionState{ + H: hs, + initiator: initiator, + window: b, + ready: false, + certState: curCertState, + messageCounter: new(uint64), + } + + return ci +} + +func (cs *ConnectionState) MarshalJSON() ([]byte, error) { + return json.Marshal(m{ + "certificate": cs.peerCert, + "initiator": cs.initiator, + "message_counter": cs.messageCounter, + "ready": cs.ready, + }) +} diff --git a/dns_server.go b/dns_server.go new file mode 100644 index 0000000..705d7b9 --- /dev/null +++ b/dns_server.go @@ -0,0 +1,125 @@ +package nebula + +import ( + "fmt" + "net" + "strconv" + "sync" + + "github.com/miekg/dns" +) + +// This whole thing should be rewritten to use context + +var dnsR *dnsRecords + +type dnsRecords struct { + sync.RWMutex + dnsMap map[string]string + hostMap *HostMap +} + +func newDnsRecords(hostMap *HostMap) *dnsRecords { + return &dnsRecords{ + dnsMap: make(map[string]string), + hostMap: hostMap, + } +} + +func (d *dnsRecords) Query(data string) string { + d.RLock() + if r, ok := d.dnsMap[data]; ok { + d.RUnlock() + return r + } + d.RUnlock() + return "" +} + +func (d *dnsRecords) QueryCert(data string) string { + ip := net.ParseIP(data[:len(data)-1]) + if ip == nil { + return "" + } + iip := ip2int(ip) + hostinfo, err := d.hostMap.QueryVpnIP(iip) + if err != nil { + return "" + } + q := hostinfo.GetCert() + if q == nil { + return "" + } + cert := q.Details + c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAFter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) + return c +} + +func (d *dnsRecords) Add(host, data string) { + d.Lock() + d.dnsMap[host] = data + d.Unlock() +} + +func parseQuery(m *dns.Msg, w dns.ResponseWriter) { + for _, q := range m.Question { + switch q.Qtype { + case dns.TypeA: + l.Debugf("Query for A %s", q.Name) + ip := dnsR.Query(q.Name) + if ip != "" { + rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) + if err == nil { + m.Answer = append(m.Answer, rr) + } + } + case dns.TypeTXT: + a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) + b := net.ParseIP(a) + // We don't answer these queries from non nebula nodes or localhost + //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) + if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" { + return + } + l.Debugf("Query for TXT %s", q.Name) + ip := dnsR.QueryCert(q.Name) + if ip != "" { + rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) + if err == nil { + m.Answer = append(m.Answer, rr) + } + } + } + } +} + +func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + m.Compress = false + + switch r.Opcode { + case dns.OpcodeQuery: + parseQuery(m, w) + } + + w.WriteMsg(m) +} + +func dnsMain(hostMap *HostMap) { + + dnsR = newDnsRecords(hostMap) + + // attach request handler func + dns.HandleFunc(".", handleDnsRequest) + + // start server + port := 53 + server := &dns.Server{Addr: ":" + strconv.Itoa(port), Net: "udp"} + l.Debugf("Starting DNS responder at %d\n", port) + err := server.ListenAndServe() + defer server.Shutdown() + if err != nil { + l.Errorf("Failed to start server: %s\n ", err.Error()) + } +} diff --git a/dns_server_test.go b/dns_server_test.go new file mode 100644 index 0000000..830dc8a --- /dev/null +++ b/dns_server_test.go @@ -0,0 +1,19 @@ +package nebula + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestParsequery(t *testing.T) { + //TODO: This test is basically pointless + hostMap := &HostMap{} + ds := newDnsRecords(hostMap) + ds.Add("test.com.com", "1.2.3.4") + + m := new(dns.Msg) + m.SetQuestion("test.com.com", dns.TypeA) + + //parseQuery(m) +} diff --git a/examples/config.yaml b/examples/config.yaml new file mode 100644 index 0000000..792298f --- /dev/null +++ b/examples/config.yaml @@ -0,0 +1,160 @@ +# This is the nebula example configuration file. You must edit, at a minimum, the static_host_map, lighthouse, and firewall sections + + +# PKI defines the location of credentials for this node. Each of these can also be inlined by using the yaml ": |" syntax. +pki: + ca: /etc/nebula/ca.crt + cert: /etc/nebula/host.crt + key: /etc/nebula/host.key + #blacklist is a list of certificate fingerprints that we will refuse to talk to + #blacklist: + # - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72 + +# The static host map defines a set of hosts with with fixed IP addresses on the internet (or any network). +# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. +# The syntax is: +# "{nebula ip}": ["{routable ip/dns name}:{routable port}"] +# Example, if your lighthouse has the nebula IP of 192.168.100.1 and has the real ip address of 100.64.22.11 and runs on port 4242: +static_host_map: + "192.168.100.1": ["100.64.22.11:4242"] + + +lighthouse: + # am_lighthouse is used to enable lighthouse functionality for a node. This should ONLY be true on nodes + # you have configured to be lighthouses in your network + am_lighthouse: false + # serve_dns optionally starts a dns listener that responds to various queries and can even be + # delegated to for resolution + #serve_dns: false + # interval is the number of seconds between updates from this node to a lighthouse. + # during updates, a node sends information about its current IP addresses to each node. + interval: 60 + # hosts is a list of lighthouse hosts this node should report to and query from + # IMPORTANT: THIS SHOULD BE EMPTY ON LIGHTHOUSE NODES + hosts: + - "192.168.100.1" + +# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, +# however using port 0 will dynamically assign a port and is recommended for roaming nodes. +listen: + host: 0.0.0.0 + port: 4242 + # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg) + # default is 64, does not support reload + #batch: 64 + # Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel + # Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/rmem_default) + # Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide + # max, net.core.rmem_max and net.core.wmem_max + #read_buffer: 10485760 + #write_buffer: 10485760 + + +# Local range is used to define a hint about the local network range, which speeds up discovering the fastest +# path to a network adjacent nebula node. +#local_range: "172.16.0.0/24" + +# Handshake mac is an optional network-wide handshake authentication step that is used to prevent nebula from +# responding to handshakes from nodes not in possession of the shared secret. This is primarily used to prevent +# detection of nebula nodes when someone is scanning a network. +#handshake_mac: + #key: "DONOTUSETHISKEY" + # You can define multiple accepted keys + #accepted_keys: + #- "DONOTUSETHISKEY" + #- "dontusethiseither" + +# sshd can expose informational and administrative functions via ssh this is a +#sshd: + # Toggles the feature + #enabled: true + # Host and port to listen on, port 22 is not allowed for your safety + #listen: 127.0.0.1:2222 + # A file containing the ssh host private key to use + # A decent way to generate one: ssh-keygen -t ed25519 -f ssh_host_ed25519_key -N "" < /dev/null + #host_key: ./ssh_host_ed25519_key + # A file containing a list of authorized public keys + #authorized_users: + #- user: steeeeve + # keys can be an array of strings or single string + #keys: + #- "ssh public key string" + +# Configure the private interface. Note: addr is baked into the nebula certificate +tun: + # Name of the device + dev: nebula1 + # Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert + drop_local_broadcast: false + # Toggles forwarding of multicast packets + drop_multicast: false + # Sets the transmit queue length, if you notice lots of transmit drops on the tun it may help to raise this number. Default is 500 + tx_queue: 500 + # Default MTU for every packet, safe setting is (and the default) 1300 for internet based traffic + mtu: 1300 + # Route based MTU overrides, you have known vpn ip paths that can support larger MTUs you can increase/decrease them here + routes: + #- mtu: 8800 + # route: 10.0.0.0/16 + +# TODO +# Configure logging level +logging: + # panic, fatal, error, warning, info, or debug. Default is info + level: info + # json or text formats currently available. Default is text + format: text + +#stats: + #type: graphite + #prefix: nebula + #protocol: tcp + #host: 127.0.0.1:9999 + #interval: 10s + + #type: prometheus + #listen: 127.0.0.1:8080 + #path: /metrics + #namespace: prometheusns + #subsystem: nebula + #interval: 10s + +# Nebula security group configuration +firewall: + conntrack: + tcp_timeout: 120h + udp_timeout: 3m + default_timeout: 10m + max_connections: 100000 + + # The firewall is default deny. There is no way to write a deny rule. + # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR + # Logical evaluation is roughly: port AND proto AND ca_sha AND ca_name AND (host OR group OR groups OR cidr) + # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). + # code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any` + # proto: `any`, `tcp`, `udp`, or `icmp` + # host: `any` or a literal hostname, ie `test-host` + # group: `any` or a literal group name, ie `default-group` + # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass + # cidr: a CIDR, `0.0.0.0/0` is any. + # ca_name: An issuing CA name + # ca_sha: An issuing CA shasum + + outbound: + # Allow all outbound traffic from this node + - port: any + proto: any + host: any + + inbound: + # Allow icmp between any nebula hosts + - port: any + proto: icmp + host: any + + # Allow tcp/443 from any host with BOTH laptop and home group + - port: 443 + proto: tcp + groups: + - laptop + - home diff --git a/examples/quickstart-vagrant/README.md b/examples/quickstart-vagrant/README.md new file mode 100644 index 0000000..09942fc --- /dev/null +++ b/examples/quickstart-vagrant/README.md @@ -0,0 +1,154 @@ +# Quickstart Guide + +This guide is intended to bring up a vagrant environment with 1 lighthouse and 2 generic hosts running nebula. + +## Pre-requisites + +There are two pre-requisites prior to bringing up the vagrant environment + + - build the binaries locally for the vagrant deploy + - create a virtualenv for ansible + +### Building the binaries + +Build the `nebula` and `nebula-cert` binaries for vagrant by doing the following + +`make bin-vagrant` (under the src directory with Makefile) + +For convenience, ansible will run this for you in every deploy (see `ansible/playbook.yml`) + +### Creating the virtualenv + +Within the `quickstart/` directory, do the following + +``` +# make a virtual environment +virtualenv venv + +# get into the virtualenv +source venv/bin/activate + +# install ansible +pip install -r requirements.yml +``` + +## Bringing up the vagrant environment + +A plugin that is used for the Vagrant environment is `vagrant-hostmanager` + +To install, run + +``` +vagrant plugin install vagrant-hostmanager +``` + +All hosts within the Vagrantfile are brought up with + +`vagrant up` + +Once the boxes are up, go into the `ansible/` directory and deploy the playbook by running + +`ansible-playbook playbook.yml -i inventory -u vagrant` + +## Testing within the vagrant env + +Once the ansible run is done, hop onto a vagrant box + +`vagrant ssh generic1.vagrant` + +or specifically + +`ssh vagrant@` (password for the vagrant user on the boxes is `vagrant`) + +See `/etc/nebula/config.yml` on a box for firewall rules. + +To see full handshakes and hostmaps, change the logging config of `/etc/nebula/config.yml` on the vagrant boxes from +info to debug. + +You can watch nebula logs by running + +``` +sudo journalctl -fu nebula +``` + +Refer to the nebula src code directory's README for further instructions on configuring nebula. + +## Troubleshooting + +### Is nebula up and running? + +Run and verify that + +``` +ifconfig +``` + +shows you an interface with the name `nebula1` being up. + +``` +vagrant@generic1:~$ ifconfig nebula1 +nebula1: flags=4305 mtu 1300 + inet 10.168.91.210 netmask 255.128.0.0 destination 10.168.91.210 + inet6 fe80::aeaf:b105:e6dc:936c prefixlen 64 scopeid 0x20 + unspec 00-00-00-00-00-00-00-00-00-00-00-00-00-00-00-00 txqueuelen 500 (UNSPEC) + RX packets 2 bytes 168 (168.0 B) + RX errors 0 dropped 0 overruns 0 frame 0 + TX packets 11 bytes 600 (600.0 B) + TX errors 0 dropped 0 overruns 0 carrier 0 collisions 0 +``` + +### Connectivity + +Are you able to ping other boxes on the private nebula network? + +The following are the private nebula ip addresses of the vagrant env + +``` +generic1.vagrant [nebula_ip] 10.168.91.210 +generic2.vagrant [nebula_ip] 10.168.91.220 +lighthouse1.vagrant [nebula_ip] 10.168.91.230 +``` + +Try pinging generic1.vagrant to and from any other box using its nebula ip above. + +Double check the nebula firewall rules under /etc/nebula/config.yml to make sure that connectivity is allowed for your use-case if on a specific port. + +``` +vagrant@lighthouse1:~$ grep -A21 firewall /etc/nebula/config.yml +firewall: + conntrack: + tcp_timeout: 12m + udp_timeout: 3m + default_timeout: 10m + max_connections: 100,000 + + inbound: + - proto: icmp + port: any + host: any + - proto: any + port: 22 + host: any + - proto: any + port: 53 + host: any + + outbound: + - proto: any + port: any + host: any +``` diff --git a/examples/quickstart-vagrant/Vagrantfile b/examples/quickstart-vagrant/Vagrantfile new file mode 100644 index 0000000..ab9408f --- /dev/null +++ b/examples/quickstart-vagrant/Vagrantfile @@ -0,0 +1,40 @@ +Vagrant.require_version ">= 2.2.6" + +nodes = [ + { :hostname => 'generic1.vagrant', :ip => '172.11.91.210', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1}, + { :hostname => 'generic2.vagrant', :ip => '172.11.91.220', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1}, + { :hostname => 'lighthouse1.vagrant', :ip => '172.11.91.230', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1}, +] + +Vagrant.configure("2") do |config| + + config.ssh.insert_key = false + + if Vagrant.has_plugin?('vagrant-cachier') + config.cache.enable :apt + else + printf("** Install vagrant-cachier plugin to speedup deploy: `vagrant plugin install vagrant-cachier`.**\n") + end + + if Vagrant.has_plugin?('vagrant-hostmanager') + config.hostmanager.enabled = true + config.hostmanager.manage_host = true + config.hostmanager.include_offline = true + else + config.vagrant.plugins = "vagrant-hostmanager" + end + + nodes.each do |node| + config.vm.define node[:hostname] do |node_config| + node_config.vm.box = node[:box] + node_config.vm.hostname = node[:hostname] + node_config.vm.network :private_network, ip: node[:ip] + node_config.vm.provider :virtualbox do |vb| + vb.memory = node[:ram] + vb.cpus = node[:cpus] + vb.customize ["modifyvm", :id, "--natdnshostresolver1", "on"] + vb.customize ['guestproperty', 'set', :id, '/VirtualBox/GuestAdd/VBoxService/--timesync-set-threshold', 10000] + end + end + end +end diff --git a/examples/quickstart-vagrant/ansible/ansible.cfg b/examples/quickstart-vagrant/ansible/ansible.cfg new file mode 100644 index 0000000..518a4f1 --- /dev/null +++ b/examples/quickstart-vagrant/ansible/ansible.cfg @@ -0,0 +1,4 @@ +[defaults] +host_key_checking = False +private_key_file = ~/.vagrant.d/insecure_private_key +become = yes diff --git a/examples/quickstart-vagrant/ansible/filter_plugins/to_nebula_ip.py b/examples/quickstart-vagrant/ansible/filter_plugins/to_nebula_ip.py new file mode 100644 index 0000000..d7bb31a --- /dev/null +++ b/examples/quickstart-vagrant/ansible/filter_plugins/to_nebula_ip.py @@ -0,0 +1,21 @@ +#!/usr/bin/python + + +class FilterModule(object): + def filters(self): + return { + 'to_nebula_ip': self.to_nebula_ip, + 'map_to_nebula_ips': self.map_to_nebula_ips, + } + + def to_nebula_ip(self, ip_str): + ip_list = map(int, ip_str.split(".")) + ip_list[0] = 10 + ip_list[1] = 168 + ip = '.'.join(map(str, ip_list)) + return ip + + def map_to_nebula_ips(self, ip_strs): + ip_list = [ self.to_nebula_ip(ip_str) for ip_str in ip_strs ] + ips = ', '.join(ip_list) + return ips diff --git a/examples/quickstart-vagrant/ansible/inventory b/examples/quickstart-vagrant/ansible/inventory new file mode 100644 index 0000000..0bae407 --- /dev/null +++ b/examples/quickstart-vagrant/ansible/inventory @@ -0,0 +1,11 @@ +[all] +generic1.vagrant +generic2.vagrant +lighthouse1.vagrant + +[generic] +generic1.vagrant +generic2.vagrant + +[lighthouse] +lighthouse1.vagrant diff --git a/examples/quickstart-vagrant/ansible/playbook.yml b/examples/quickstart-vagrant/ansible/playbook.yml new file mode 100644 index 0000000..a159d0b --- /dev/null +++ b/examples/quickstart-vagrant/ansible/playbook.yml @@ -0,0 +1,20 @@ +--- +- name: test connection to vagrant boxes + hosts: all + tasks: + - debug: msg=ok + +- name: build nebula binaries locally + connection: local + hosts: localhost + tasks: + - command: chdir=../../../ make bin-vagrant + tags: + - build-nebula + +- name: install nebula on all vagrant hosts + hosts: all + become: yes + gather_facts: yes + roles: + - nebula diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/defaults/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/defaults/main.yml new file mode 100644 index 0000000..f8e7a99 --- /dev/null +++ b/examples/quickstart-vagrant/ansible/roles/nebula/defaults/main.yml @@ -0,0 +1,3 @@ +--- +# defaults file for nebula +nebula_config_directory: "/etc/nebula/" diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/files/systemd.nebula.service b/examples/quickstart-vagrant/ansible/roles/nebula/files/systemd.nebula.service new file mode 100644 index 0000000..13c5ff2 --- /dev/null +++ b/examples/quickstart-vagrant/ansible/roles/nebula/files/systemd.nebula.service @@ -0,0 +1,15 @@ +[Unit] +Description=nebula +Wants=basic.target +After=basic.target network.target + +[Service] +SyslogIdentifier=nebula +StandardOutput=syslog +StandardError=syslog +ExecReload=/bin/kill -HUP $MAINPID +ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml +Restart=always + +[Install] +WantedBy=multi-user.target diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt b/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt new file mode 100644 index 0000000..6148687 --- /dev/null +++ b/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt @@ -0,0 +1,5 @@ +-----BEGIN NEBULA CERTIFICATE----- +CkAKDm5lYnVsYSB0ZXN0IENBKNXC1NYFMNXIhO0GOiCmVYeZ9tkB4WEnawmkrca+ +hsAg9otUFhpAowZeJ33KVEABEkAORybHQUUyVFbKYzw0JHfVzAQOHA4kwB1yP9IV +KpiTw9+ADz+wA+R5tn9B+L8+7+Apc+9dem4BQULjA5mRaoYN +-----END NEBULA CERTIFICATE----- diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key b/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key new file mode 100644 index 0000000..394043c --- /dev/null +++ b/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key @@ -0,0 +1,4 @@ +-----BEGIN NEBULA ED25519 PRIVATE KEY----- +FEXZKMSmg8CgIODR0ymUeNT3nbnVpMi7nD79UgkCRHWmVYeZ9tkB4WEnawmkrca+ +hsAg9otUFhpAowZeJ33KVA== +-----END NEBULA ED25519 PRIVATE KEY----- diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/handlers/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/handlers/main.yml new file mode 100644 index 0000000..0e09599 --- /dev/null +++ b/examples/quickstart-vagrant/ansible/roles/nebula/handlers/main.yml @@ -0,0 +1,5 @@ +--- +# handlers file for nebula + +- name: restart nebula + service: name=nebula state=restarted diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/tasks/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/tasks/main.yml new file mode 100644 index 0000000..b96e7c7 --- /dev/null +++ b/examples/quickstart-vagrant/ansible/roles/nebula/tasks/main.yml @@ -0,0 +1,56 @@ +--- +# tasks file for nebula + +- name: get the vagrant network interface and set fact + set_fact: + vagrant_ifce: "ansible_{{ ansible_interfaces | difference(['lo',ansible_default_ipv4.alias]) | sort | first }}" + tags: + - nebula-conf + +- name: install built nebula binary + copy: src=../../../../../{{ item }} dest=/usr/local/bin mode=0755 + with_items: + - nebula + - nebula-cert + +- name: create nebula config directory + file: path="{{ nebula_config_directory }}" state=directory mode=0755 + +- name: temporarily copy over root.crt and root.key to sign + copy: src={{ item }} dest=/opt/{{ item }} + with_items: + - vagrant-test-ca.key + - vagrant-test-ca.crt + +- name: sign using the root key + command: nebula-cert sign -ca-crt /opt/vagrant-test-ca.crt -ca-key /opt/vagrant-test-ca.key -duration 4320h -groups vagrant -ip {{ hostvars[inventory_hostname][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }}/9 -name {{ ansible_hostname }}.nebula -out-crt /etc/nebula/host.crt -out-key /etc/nebula/host.key + +- name: remove root.key used to sign + file: dest=/opt/{{ item }} state=absent + with_items: + - vagrant-test-ca.key + +- name: write the content of the trusted ca certificate + copy: src="vagrant-test-ca.crt" dest="/etc/nebula/vagrant-test-ca.crt" + notify: restart nebula + +- name: Create config directory + file: path="{{ nebula_config_directory }}" owner=root group=root mode=0755 state=directory + +- name: nebula config + template: src=config.yml.j2 dest="/etc/nebula/config.yml" mode=0644 owner=root group=root + notify: restart nebula + tags: + - nebula-conf + +- name: nebula systemd + copy: src=systemd.nebula.service dest="/etc/systemd/system/nebula.service" mode=0644 owner=root group=root + register: addconf + notify: restart nebula + +- name: maybe reload systemd + shell: systemctl daemon-reload + when: addconf.changed + +- name: nebula running + service: name="nebula" state=started enabled=yes diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/templates/config.yml.j2 b/examples/quickstart-vagrant/ansible/roles/nebula/templates/config.yml.j2 new file mode 100644 index 0000000..708409a --- /dev/null +++ b/examples/quickstart-vagrant/ansible/roles/nebula/templates/config.yml.j2 @@ -0,0 +1,84 @@ +pki: + ca: /etc/nebula/vagrant-test-ca.crt + cert: /etc/nebula/host.crt + key: /etc/nebula/host.key + +# Port Nebula will be listening on +listen: + host: 0.0.0.0 + port: 4242 + +# sshd can expose informational and administrative functions via ssh +sshd: + # Toggles the feature + enabled: true + # Host and port to listen on + listen: 127.0.0.1:2222 + # A file containing the ssh host private key to use + host_key: /etc/ssh/ssh_host_ed25519_key + # A file containing a list of authorized public keys + authorized_users: +{% for user in nebula_users %} + - user: {{ user.name }} + keys: +{% for key in user.ssh_auth_keys %} + - "{{ key }}" +{% endfor %} +{% endfor %} + +local_range: 10.168.0.0/16 + +static_host_map: +# lighthouse + {{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }}: ["{{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address']}}:4242"] + +default_route: "0.0.0.0" + +lighthouse: +{% if 'lighthouse' in group_names %} + am_lighthouse: true + serve_dns: true +{% else %} + am_lighthouse: false +{% endif %} + interval: 60 + hosts: + - {{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }} + +# Configure the private interface +tun: + dev: nebula1 + # Sets MTU of the tun dev. + # MTU of the tun must be smaller than the MTU of the eth0 interface + mtu: 1300 + +# TODO +# Configure logging level +logging: + level: info + format: json + +firewall: + conntrack: + tcp_timeout: 12m + udp_timeout: 3m + default_timeout: 10m + max_connections: 100,000 + + inbound: + - proto: icmp + port: any + host: any + - proto: any + port: 22 + host: any +{% if "lighthouse" in groups %} + - proto: any + port: 53 + host: any +{% endif %} + + outbound: + - proto: any + port: any + host: any diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/vars/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/vars/main.yml new file mode 100644 index 0000000..7a3ae5d --- /dev/null +++ b/examples/quickstart-vagrant/ansible/roles/nebula/vars/main.yml @@ -0,0 +1,7 @@ +--- +# vars file for nebula + +nebula_users: + - name: user1 + ssh_auth_keys: + - "ed25519 place-your-ssh-public-key-here" diff --git a/examples/quickstart-vagrant/requirements.yml b/examples/quickstart-vagrant/requirements.yml new file mode 100644 index 0000000..90d4055 --- /dev/null +++ b/examples/quickstart-vagrant/requirements.yml @@ -0,0 +1 @@ +ansible diff --git a/examples/service_scripts/nebula.init.d.sh b/examples/service_scripts/nebula.init.d.sh new file mode 100644 index 0000000..34fe17e --- /dev/null +++ b/examples/service_scripts/nebula.init.d.sh @@ -0,0 +1,51 @@ +#!/bin/sh +### BEGIN INIT INFO +# Provides: nebula +# Required-Start: $local_fs $network +# Required-Stop: $local_fs $network +# Default-Start: 2 3 4 5 +# Default-Stop: 0 1 6 +# Description: nebula mesh vpn client +### END INIT INFO + +SCRIPT="/usr/local/bin/nebula -config /etc/nebula/config.yml" +RUNAS=root + +PIDFILE=/var/run/nebula.pid +LOGFILE=/var/log/nebula.log + +start() { + if [ -f $PIDFILE ] && kill -0 $(cat $PIDFILE); then + echo 'Service already running' >&2 + return 1 + fi + echo 'Starting nebula service…' >&2 + local CMD="$SCRIPT &> \"$LOGFILE\" & echo \$!" + su -c "$CMD" $RUNAS > "$PIDFILE" + echo 'Service started' >&2 +} + +stop() { + if [ ! -f "$PIDFILE" ] || ! kill -0 $(cat "$PIDFILE"); then + echo 'Service not running' >&2 + return 1 + fi + echo 'Stopping nebula service…' >&2 + kill -15 $(cat "$PIDFILE") && rm -f "$PIDFILE" + echo 'Service stopped' >&2 +} + +case "$1" in + start) + start + ;; + stop) + stop + ;; + restart) + stop + start + ;; + *) + echo "Usage: $0 {start|stop|restart}" +esac diff --git a/examples/service_scripts/nebula.service b/examples/service_scripts/nebula.service new file mode 100644 index 0000000..13c5ff2 --- /dev/null +++ b/examples/service_scripts/nebula.service @@ -0,0 +1,15 @@ +[Unit] +Description=nebula +Wants=basic.target +After=basic.target network.target + +[Service] +SyslogIdentifier=nebula +StandardOutput=syslog +StandardError=syslog +ExecReload=/bin/kill -HUP $MAINPID +ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml +Restart=always + +[Install] +WantedBy=multi-user.target diff --git a/firewall.go b/firewall.go new file mode 100644 index 0000000..0ada61b --- /dev/null +++ b/firewall.go @@ -0,0 +1,789 @@ +package nebula + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "net" + "sync" + "time" + + "crypto/sha256" + "encoding/hex" + "errors" + "reflect" + "strconv" + "strings" + + "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/cert" +) + +const ( + fwProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever + fwProtoTCP = 6 + fwProtoUDP = 17 + fwProtoICMP = 1 + + fwPortAny = 0 // Special value for matching `port: any` + fwPortFragment = -1 // Special value for matching `port: fragment` +) + +const tcpACK = 0x10 +const tcpFIN = 0x01 + +type FirewallInterface interface { + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error +} + +type conn struct { + Expires time.Time // Time when this conntrack entry will expire + Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack + Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set +} + +// TODO: need conntrack max tracked connections handling +type Firewall struct { + Conns map[FirewallPacket]*conn + + InRules *FirewallTable + OutRules *FirewallTable + + //TODO: we should have many more options for TCP, an option for ICMP, and mimic the kernel a bit better + // https://www.kernel.org/doc/Documentation/networking/nf_conntrack-sysctl.txt + TCPTimeout time.Duration //linux: 5 days max + UDPTimeout time.Duration //linux: 180s max + DefaultTimeout time.Duration //linux: 600s + + TimerWheel *TimerWheel + + // Used to ensure we don't emit local packets for ips we don't own + localIps *CIDRTree + + connMutex sync.Mutex + rules string + + trackTCPRTT bool + metricTCPRTT metrics.Histogram +} + +type FirewallTable struct { + TCP firewallPort + UDP firewallPort + ICMP firewallPort + AnyProto firewallPort +} + +func newFirewallTable() *FirewallTable { + return &FirewallTable{ + TCP: firewallPort{}, + UDP: firewallPort{}, + ICMP: firewallPort{}, + AnyProto: firewallPort{}, + } +} + +type FirewallRule struct { + // Any makes Hosts, Groups, and CIDR irrelevant. CAName and CASha still need to be checked + Any bool + Hosts map[string]struct{} + Groups [][]string + CIDR *CIDRTree + CANames map[string]struct{} + CAShas map[string]struct{} +} + +// Even though ports are uint16, int32 maps are faster for lookup +// Plus we can use `-1` for fragment rules +type firewallPort map[int32]*FirewallRule + +type FirewallPacket struct { + LocalIP uint32 + RemoteIP uint32 + LocalPort uint16 + RemotePort uint16 + Protocol uint8 + Fragment bool +} + +func (fp *FirewallPacket) Copy() *FirewallPacket { + return &FirewallPacket{ + LocalIP: fp.LocalIP, + RemoteIP: fp.RemoteIP, + LocalPort: fp.LocalPort, + RemotePort: fp.RemotePort, + Protocol: fp.Protocol, + Fragment: fp.Fragment, + } +} + +func (fp FirewallPacket) MarshalJSON() ([]byte, error) { + var proto string + switch fp.Protocol { + case fwProtoTCP: + proto = "tcp" + case fwProtoICMP: + proto = "icmp" + case fwProtoUDP: + proto = "udp" + default: + proto = fmt.Sprintf("unknown %v", fp.Protocol) + } + return json.Marshal(m{ + "LocalIP": int2ip(fp.LocalIP).String(), + "RemoteIP": int2ip(fp.RemoteIP).String(), + "LocalPort": fp.LocalPort, + "RemotePort": fp.RemotePort, + "Protocol": proto, + "Fragment": fp.Fragment, + }) +} + +// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. +func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { + //TODO: error on 0 duration + var min, max time.Duration + + if tcpTimeout < UDPTimeout { + min = tcpTimeout + max = UDPTimeout + } else { + min = UDPTimeout + max = tcpTimeout + } + + if defaultTimeout < min { + min = defaultTimeout + } else if defaultTimeout > max { + max = defaultTimeout + } + + localIps := NewCIDRTree() + for _, ip := range c.Details.Ips { + localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) + } + + for _, n := range c.Details.Subnets { + localIps.AddCIDR(n, struct{}{}) + } + + return &Firewall{ + Conns: make(map[FirewallPacket]*conn), + InRules: newFirewallTable(), + OutRules: newFirewallTable(), + TimerWheel: NewTimerWheel(min, max), + TCPTimeout: tcpTimeout, + UDPTimeout: UDPTimeout, + DefaultTimeout: defaultTimeout, + localIps: localIps, + metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)), + } +} + +func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) { + fw := NewFirewall( + c.GetDuration("firewall.conntrack.tcp_timeout", time.Duration(time.Minute*12)), + c.GetDuration("firewall.conntrack.udp_timeout", time.Duration(time.Minute*3)), + c.GetDuration("firewall.conntrack.default_timeout", time.Duration(time.Minute*10)), + nc, + //TODO: max_connections + ) + + err := AddFirewallRulesFromConfig(false, c, fw) + if err != nil { + return nil, err + } + + err = AddFirewallRulesFromConfig(true, c, fw) + if err != nil { + return nil, err + } + + return fw, nil +} + +// AddRule properly creates the in memory rule structure for a firewall table. +func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { + + // We need this rule string because we generate a hash. Removing this will break firewall reload. + ruleString := fmt.Sprintf( + "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s", + incoming, proto, startPort, endPort, groups, host, ip, caName, caSha, + ) + f.rules += ruleString + "\n" + + direction := "incoming" + if !incoming { + direction = "outgoing" + } + l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}). + Info("Firewall rule added") + + var ( + ft *FirewallTable + fp firewallPort + ) + + if incoming { + ft = f.InRules + } else { + ft = f.OutRules + } + + switch proto { + case fwProtoTCP: + fp = ft.TCP + case fwProtoUDP: + fp = ft.UDP + case fwProtoICMP: + fp = ft.ICMP + case fwProtoAny: + fp = ft.AnyProto + default: + return fmt.Errorf("unknown protocol %v", proto) + } + + return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha) +} + +// GetRuleHash returns a hash representation of all inbound and outbound rules +func (f *Firewall) GetRuleHash() string { + sum := sha256.Sum256([]byte(f.rules)) + return hex.EncodeToString(sum[:]) +} + +func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterface) error { + var table string + if inbound { + table = "firewall.inbound" + } else { + table = "firewall.outbound" + } + + r := config.Get(table) + if r == nil { + return nil + } + + rs, ok := r.([]interface{}) + if !ok { + return fmt.Errorf("%s failed to parse, should be an array of rules", table) + } + + for i, t := range rs { + var groups []string + r, err := convertRule(t) + if err != nil { + return fmt.Errorf("%s rule #%v; %s", table, i, err) + } + + if r.Code != "" && r.Port != "" { + return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i) + } + + if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" { + return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i) + } + + if len(r.Groups) > 0 { + groups = r.Groups + } + + if r.Group != "" { + // Check if we have both groups and group provided in the rule config + if len(groups) > 0 { + return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i) + } + + groups = []string{r.Group} + } + + var sPort, errPort string + if r.Code != "" { + errPort = "code" + sPort = r.Code + } else { + errPort = "port" + sPort = r.Port + } + + startPort, endPort, err := parsePort(sPort) + if err != nil { + return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err) + } + + var proto uint8 + switch r.Proto { + case "any": + proto = fwProtoAny + case "tcp": + proto = fwProtoTCP + case "udp": + proto = fwProtoUDP + case "icmp": + proto = fwProtoICMP + default: + return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) + } + + var cidr *net.IPNet + if r.Cidr != "" { + _, cidr, err = net.ParseCIDR(r.Cidr) + if err != nil { + return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) + } + } + + err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha) + if err != nil { + return fmt.Errorf("%s rule #%v; `%s`", table, i, err) + } + } + + return nil +} + +func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { + // Check if we spoke to this tuple, if we did then allow this packet + if f.inConns(packet, fp, incoming) { + return false + } + + // Make sure we are supposed to be handling this local ip address + if f.localIps.Contains(fp.LocalIP) == nil { + return true + } + + table := f.OutRules + if incoming { + table = f.InRules + } + + // We now know which firewall table to check against + if !table.match(fp, incoming, c, caPool) { + return true + } + + // We always want to conntrack since it is a faster operation + f.addConn(packet, fp, incoming) + + return false +} + +// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new +// firewall object is created +func (f *Firewall) Destroy() { + //TODO: clean references if/when needed +} + +func (f *Firewall) EmitStats() { + conntrackCount := len(f.Conns) + metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount)) +} + +func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool { + f.connMutex.Lock() + + // Purge every time we test + ep, has := f.TimerWheel.Purge() + if has { + f.evict(ep) + } + + c, ok := f.Conns[fp] + + if !ok { + f.connMutex.Unlock() + return false + } + + switch fp.Protocol { + case fwProtoTCP: + c.Expires = time.Now().Add(f.TCPTimeout) + if incoming { + f.checkTCPRTT(c, packet) + } else { + setTCPRTTTracking(c, packet) + } + case fwProtoUDP: + c.Expires = time.Now().Add(f.UDPTimeout) + default: + c.Expires = time.Now().Add(f.DefaultTimeout) + } + + f.connMutex.Unlock() + + return true +} + +func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) { + var timeout time.Duration + c := &conn{} + + switch fp.Protocol { + case fwProtoTCP: + timeout = f.TCPTimeout + if !incoming { + setTCPRTTTracking(c, packet) + } + case fwProtoUDP: + timeout = f.UDPTimeout + default: + timeout = f.DefaultTimeout + } + + f.connMutex.Lock() + if _, ok := f.Conns[fp]; !ok { + f.TimerWheel.Add(fp, timeout) + } + + c.Expires = time.Now().Add(timeout) + f.Conns[fp] = c + f.connMutex.Unlock() +} + +// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel +// Caller must own the connMutex lock! +func (f *Firewall) evict(p FirewallPacket) { + //TODO: report a stat if the tcp rtt tracking was never resolved? + // Are we still tracking this conn? + t, ok := f.Conns[p] + if !ok { + return + } + + newT := t.Expires.Sub(time.Now()) + + // Timeout is in the future, re-add the timer + if newT > 0 { + f.TimerWheel.Add(p, newT) + return + } + + // This conn is done + delete(f.Conns, p) +} + +func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { + if ft.AnyProto.match(p, incoming, c, caPool) { + return true + } + + switch p.Protocol { + case fwProtoTCP: + if ft.TCP.match(p, incoming, c, caPool) { + return true + } + case fwProtoUDP: + if ft.UDP.match(p, incoming, c, caPool) { + return true + } + case fwProtoICMP: + if ft.ICMP.match(p, incoming, c, caPool) { + return true + } + } + + return false +} + +func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { + if startPort > endPort { + return fmt.Errorf("start port was lower than end port") + } + + for i := startPort; i <= endPort; i++ { + if _, ok := fp[i]; !ok { + fp[i] = &FirewallRule{ + Groups: make([][]string, 0), + Hosts: make(map[string]struct{}), + CIDR: NewCIDRTree(), + CANames: make(map[string]struct{}), + CAShas: make(map[string]struct{}), + } + } + + if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil { + return err + } + } + + return nil +} + +func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { + // We don't have any allowed ports, bail + if fp == nil { + return false + } + + var port int32 + + if p.Fragment { + port = fwPortFragment + } else if incoming { + port = int32(p.LocalPort) + } else { + port = int32(p.RemotePort) + } + + if fp[port].match(p, c, caPool) { + return true + } + + return fp[fwPortAny].match(p, c, caPool) +} + +func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caName string, caSha string) error { + if caName != "" { + fr.CANames[caName] = struct{}{} + } + + if caSha != "" { + fr.CAShas[caSha] = struct{}{} + } + + if fr.Any { + return nil + } + + if fr.isAny(groups, host, ip) { + fr.Any = true + // If it's any we need to wipe out any pre-existing rules to save on memory + fr.Groups = make([][]string, 0) + fr.Hosts = make(map[string]struct{}) + fr.CIDR = NewCIDRTree() + } else { + if len(groups) > 0 { + fr.Groups = append(fr.Groups, groups) + } + + if host != "" { + fr.Hosts[host] = struct{}{} + } + + if ip != nil { + fr.CIDR.AddCIDR(ip, struct{}{}) + } + } + + return nil +} + +func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { + for _, group := range groups { + if group == "any" { + return true + } + } + + if host == "any" { + return true + } + + if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) { + return true + } + + return false +} + +func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool { + if fr == nil { + return false + } + + // CASha and CAName always need to be checked + if len(fr.CAShas) > 0 { + if _, ok := fr.CAShas[c.Details.Issuer]; !ok { + return false + } + } + + if len(fr.CANames) > 0 { + s, err := caPool.GetCAForCert(c) + if err != nil { + return false + } + if _, ok := fr.CANames[s.Details.Name]; !ok { + return false + } + } + + // Shortcut path for if groups, hosts, or cidr contained an `any` + if fr.Any { + return true + } + + // Need any of group, host, or cidr to match + for _, sg := range fr.Groups { + found := false + + for _, g := range sg { + if _, ok := c.Details.InvertedGroups[g]; !ok { + found = false + break + } + + found = true + } + + if found { + return true + } + } + + if fr.Hosts != nil { + if _, ok := fr.Hosts[c.Details.Name]; ok { + return true + } + } + + if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil { + return true + } + + // No host, group, or cidr matched, bye bye + return false +} + +type rule struct { + Port string + Code string + Proto string + Host string + Group string + Groups []string + Cidr string + CAName string + CASha string +} + +func convertRule(p interface{}) (rule, error) { + r := rule{} + + m, ok := p.(map[interface{}]interface{}) + if !ok { + return r, errors.New("could not parse rule") + } + + toString := func(k string, m map[interface{}]interface{}) string { + v, ok := m[k] + if !ok { + return "" + } + return fmt.Sprintf("%v", v) + } + + r.Port = toString("port", m) + r.Code = toString("code", m) + r.Proto = toString("proto", m) + r.Host = toString("host", m) + r.Group = toString("group", m) + r.Cidr = toString("cidr", m) + r.CAName = toString("ca_name", m) + r.CASha = toString("ca_sha", m) + + if rg, ok := m["groups"]; ok { + switch reflect.TypeOf(rg).Kind() { + case reflect.Slice: + v := reflect.ValueOf(rg) + r.Groups = make([]string, v.Len()) + for i := 0; i < v.Len(); i++ { + r.Groups[i] = v.Index(i).Interface().(string) + } + case reflect.String: + r.Groups = []string{rg.(string)} + default: + r.Groups = []string{fmt.Sprintf("%v", rg)} + } + } + + return r, nil +} + +func parsePort(s string) (startPort, endPort int32, err error) { + if s == "any" { + startPort = fwPortAny + endPort = fwPortAny + + } else if s == "fragment" { + startPort = fwPortFragment + endPort = fwPortFragment + + } else if strings.Contains(s, `-`) { + sPorts := strings.SplitN(s, `-`, 2) + sPorts[0] = strings.Trim(sPorts[0], " ") + sPorts[1] = strings.Trim(sPorts[1], " ") + + if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" { + return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) + } + + rStartPort, err := strconv.Atoi(sPorts[0]) + if err != nil { + return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) + } + + rEndPort, err := strconv.Atoi(sPorts[1]) + if err != nil { + return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) + } + + startPort = int32(rStartPort) + endPort = int32(rEndPort) + + if startPort == fwPortAny { + endPort = fwPortAny + } + + } else { + rPort, err := strconv.Atoi(s) + if err != nil { + return 0, 0, fmt.Errorf("was not a number; `%s`", s) + } + startPort = int32(rPort) + endPort = startPort + } + + return +} + +//TODO: write tests for these +func setTCPRTTTracking(c *conn, p []byte) { + if c.Seq != 0 { + return + } + + ihl := int(p[0]&0x0f) << 2 + + // Don't track FIN packets + if uint8(p[ihl+13])&tcpFIN != 0 { + return + } + + c.Seq = binary.BigEndian.Uint32(p[ihl+4 : ihl+8]) + c.Sent = time.Now() +} + +func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool { + if c.Seq == 0 { + return false + } + + ihl := int(p[0]&0x0f) << 2 + if uint8(p[ihl+13])&tcpACK == 0 { + return false + } + + // Deal with wrap around, signed int cuts the ack window in half + // 0 is a bad ack, no data acknowledged + // positive number is a bad ack, ack is over half the window away + if int32(c.Seq-binary.BigEndian.Uint32(p[ihl+8:ihl+12])) >= 0 { + return false + } + + f.metricTCPRTT.Update(time.Since(c.Sent).Nanoseconds()) + c.Seq = 0 + return true +} diff --git a/firewall_test.go b/firewall_test.go new file mode 100644 index 0000000..a7c22bc --- /dev/null +++ b/firewall_test.go @@ -0,0 +1,687 @@ +package nebula + +import ( + "encoding/binary" + "errors" + "github.com/rcrowley/go-metrics" + "github.com/stretchr/testify/assert" + "math" + "net" + "github.com/slackhq/nebula/cert" + "testing" + "time" +) + +func TestNewFirewall(t *testing.T) { + c := &cert.NebulaCertificate{} + fw := NewFirewall(time.Second, time.Minute, time.Hour, c) + assert.NotNil(t, fw.Conns) + assert.NotNil(t, fw.InRules) + assert.NotNil(t, fw.OutRules) + assert.NotNil(t, fw.TimerWheel) + assert.Equal(t, time.Second, fw.TCPTimeout) + assert.Equal(t, time.Minute, fw.UDPTimeout) + assert.Equal(t, time.Hour, fw.DefaultTimeout) + + assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) + assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) + assert.Equal(t, 3601, fw.TimerWheel.wheelLen) + + fw = NewFirewall(time.Second, time.Hour, time.Minute, c) + assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) + assert.Equal(t, 3601, fw.TimerWheel.wheelLen) + + fw = NewFirewall(time.Hour, time.Second, time.Minute, c) + assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) + assert.Equal(t, 3601, fw.TimerWheel.wheelLen) + + fw = NewFirewall(time.Hour, time.Minute, time.Second, c) + assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) + assert.Equal(t, 3601, fw.TimerWheel.wheelLen) + + fw = NewFirewall(time.Minute, time.Hour, time.Second, c) + assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) + assert.Equal(t, 3601, fw.TimerWheel.wheelLen) + + fw = NewFirewall(time.Minute, time.Second, time.Hour, c) + assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration) + assert.Equal(t, 3601, fw.TimerWheel.wheelLen) +} + +func TestFirewall_AddRule(t *testing.T) { + c := &cert.NebulaCertificate{} + fw := NewFirewall(time.Second, time.Minute, time.Hour, c) + assert.NotNil(t, fw.InRules) + assert.NotNil(t, fw.OutRules) + + _, ti, _ := net.ParseCIDR("1.2.3.4/32") + + assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", "")) + // Make sure an empty rule creates structure but doesn't allow anything to flow + //TODO: ideally an empty rule would return an error + assert.False(t, fw.InRules.TCP[1].Any) + assert.Empty(t, fw.InRules.TCP[1].Groups) + assert.Empty(t, fw.InRules.TCP[1].Hosts) + assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left) + assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right) + assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value) + + fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) + assert.False(t, fw.InRules.UDP[1].Any) + assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1") + assert.Empty(t, fw.InRules.UDP[1].Hosts) + assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left) + assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right) + assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value) + + fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) + assert.False(t, fw.InRules.ICMP[1].Any) + assert.Empty(t, fw.InRules.ICMP[1].Groups) + assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1") + assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left) + assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right) + assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value) + + fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", "")) + assert.False(t, fw.OutRules.AnyProto[1].Any) + assert.Empty(t, fw.OutRules.AnyProto[1].Groups) + assert.Empty(t, fw.OutRules.AnyProto[1].Hosts) + assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP))) + + fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) + assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") + + fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha")) + assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") + + // Set any and clear fields + fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) + assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0]) + assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1") + assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP))) + + // run twice just to make sure + assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) + assert.True(t, fw.OutRules.AnyProto[0].Any) + assert.Empty(t, fw.OutRules.AnyProto[0].Groups) + assert.Empty(t, fw.OutRules.AnyProto[0].Hosts) + assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left) + assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right) + assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value) + + fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) + assert.True(t, fw.OutRules.AnyProto[0].Any) + + fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") + assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", "")) + assert.True(t, fw.OutRules.AnyProto[0].Any) + + // Test error conditions + fw = NewFirewall(time.Second, time.Minute, time.Hour, c) + assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", "")) + assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", "")) +} + +func TestFirewall_Drop(t *testing.T) { + p := FirewallPacket{ + ip2int(net.IPv4(1, 2, 3, 4)), + 101, + 10, + 90, + fwProtoUDP, + false, + } + + ipNet := net.IPNet{ + IP: net.IPv4(1, 2, 3, 4), + Mask: net.IPMask{255, 255, 255, 0}, + } + + c := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host1", + Ips: []*net.IPNet{&ipNet}, + Groups: []string{"default-group"}, + Issuer: "signer-shasum", + }, + } + + fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) + cp := cert.NewCAPool() + + // Drop outbound + assert.True(t, fw.Drop([]byte{}, p, false, &c, cp)) + // Allow inbound + assert.False(t, fw.Drop([]byte{}, p, true, &c, cp)) + // Allow outbound because conntrack + assert.False(t, fw.Drop([]byte{}, p, false, &c, cp)) + + // test caSha assertions true + fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum")) + assert.False(t, fw.Drop([]byte{}, p, true, &c, cp)) + + // test caSha assertions false + fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope")) + assert.True(t, fw.Drop([]byte{}, p, true, &c, cp)) + + // test caName true + cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} + fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", "")) + assert.False(t, fw.Drop([]byte{}, p, true, &c, cp)) + + // test caName false + cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} + fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", "")) + assert.True(t, fw.Drop([]byte{}, p, true, &c, cp)) +} + +func BenchmarkFirewallTable_match(b *testing.B) { + ft := FirewallTable{ + TCP: firewallPort{}, + } + + _, n, _ := net.ParseCIDR("172.1.1.1/32") + ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "") + ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "") + ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "") + ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "") + ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "") + cp := cert.NewCAPool() + + b.Run("fail on proto", func(b *testing.B) { + c := &cert.NebulaCertificate{} + for n := 0; n < b.N; n++ { + ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp) + } + }) + + b.Run("fail on port", func(b *testing.B) { + c := &cert.NebulaCertificate{} + for n := 0; n < b.N; n++ { + ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp) + } + }) + + b.Run("fail all group, name, and cidr", func(b *testing.B) { + _, ip, _ := net.ParseCIDR("9.254.254.254/32") + c := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + InvertedGroups: map[string]struct{}{"nope": {}}, + Name: "nope", + Ips: []*net.IPNet{ip}, + }, + } + for n := 0; n < b.N; n++ { + ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp) + } + }) + + b.Run("pass on group", func(b *testing.B) { + c := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + InvertedGroups: map[string]struct{}{"good-group": {}}, + Name: "nope", + }, + } + for n := 0; n < b.N; n++ { + ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp) + } + }) + + b.Run("pass on name", func(b *testing.B) { + c := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + InvertedGroups: map[string]struct{}{"nope": {}}, + Name: "good-host", + }, + } + for n := 0; n < b.N; n++ { + ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp) + } + }) + + b.Run("pass on ip", func(b *testing.B) { + ip := ip2int(net.IPv4(172, 1, 1, 1)) + c := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + InvertedGroups: map[string]struct{}{"nope": {}}, + Name: "good-host", + }, + } + for n := 0; n < b.N; n++ { + ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) + } + }) + + ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") + + b.Run("pass on ip with any port", func(b *testing.B) { + ip := ip2int(net.IPv4(172, 1, 1, 1)) + c := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + InvertedGroups: map[string]struct{}{"nope": {}}, + Name: "good-host", + }, + } + for n := 0; n < b.N; n++ { + ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) + } + }) +} + +func TestFirewall_Drop2(t *testing.T) { + p := FirewallPacket{ + ip2int(net.IPv4(1, 2, 3, 4)), + 101, + 10, + 90, + fwProtoUDP, + false, + } + + ipNet := net.IPNet{ + IP: net.IPv4(1, 2, 3, 4), + Mask: net.IPMask{255, 255, 255, 0}, + } + + c := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host1", + Ips: []*net.IPNet{&ipNet}, + InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}}, + }, + } + + c1 := cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + Name: "host1", + Ips: []*net.IPNet{&ipNet}, + InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, + }, + } + + fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) + assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", "")) + cp := cert.NewCAPool() + + // c1 lacks the proper groups + assert.True(t, fw.Drop([]byte{}, p, true, &c1, cp)) + // c has the proper groups + assert.False(t, fw.Drop([]byte{}, p, true, &c, cp)) +} + +func BenchmarkLookup(b *testing.B) { + ml := func(m map[string]struct{}, a [][]string) { + for n := 0; n < b.N; n++ { + for _, sg := range a { + found := false + + for _, g := range sg { + if _, ok := m[g]; !ok { + found = false + break + } + + found = true + } + + if found { + return + } + } + } + } + + b.Run("array to map best", func(b *testing.B) { + m := map[string]struct{}{ + "1ne": {}, + "2wo": {}, + "3hr": {}, + "4ou": {}, + "5iv": {}, + "6ix": {}, + } + + a := [][]string{ + {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"}, + {"one", "2wo", "3hr", "4ou", "5iv", "6ix"}, + {"one", "two", "3hr", "4ou", "5iv", "6ix"}, + {"one", "two", "thr", "4ou", "5iv", "6ix"}, + {"one", "two", "thr", "fou", "5iv", "6ix"}, + {"one", "two", "thr", "fou", "fiv", "6ix"}, + {"one", "two", "thr", "fou", "fiv", "six"}, + } + + for n := 0; n < b.N; n++ { + ml(m, a) + } + }) + + b.Run("array to map worst", func(b *testing.B) { + m := map[string]struct{}{ + "one": {}, + "two": {}, + "thr": {}, + "fou": {}, + "fiv": {}, + "six": {}, + } + + a := [][]string{ + {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"}, + {"one", "2wo", "3hr", "4ou", "5iv", "6ix"}, + {"one", "two", "3hr", "4ou", "5iv", "6ix"}, + {"one", "two", "thr", "4ou", "5iv", "6ix"}, + {"one", "two", "thr", "fou", "5iv", "6ix"}, + {"one", "two", "thr", "fou", "fiv", "6ix"}, + {"one", "two", "thr", "fou", "fiv", "six"}, + } + + for n := 0; n < b.N; n++ { + ml(m, a) + } + }) + + //TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster +} + +func Test_parsePort(t *testing.T) { + _, _, err := parsePort("") + assert.EqualError(t, err, "was not a number; ``") + + _, _, err = parsePort(" ") + assert.EqualError(t, err, "was not a number; ` `") + + _, _, err = parsePort("-") + assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`") + + _, _, err = parsePort(" - ") + assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") + + _, _, err = parsePort("a-b") + assert.EqualError(t, err, "beginning range was not a number; `a`") + + _, _, err = parsePort("1-b") + assert.EqualError(t, err, "ending range was not a number; `b`") + + s, e, err := parsePort(" 1 - 2 ") + assert.Equal(t, int32(1), s) + assert.Equal(t, int32(2), e) + assert.Nil(t, err) + + s, e, err = parsePort("0-1") + assert.Equal(t, int32(0), s) + assert.Equal(t, int32(0), e) + assert.Nil(t, err) + + s, e, err = parsePort("9919") + assert.Equal(t, int32(9919), s) + assert.Equal(t, int32(9919), e) + assert.Nil(t, err) + + s, e, err = parsePort("any") + assert.Equal(t, int32(0), s) + assert.Equal(t, int32(0), e) + assert.Nil(t, err) +} + +func TestNewFirewallFromConfig(t *testing.T) { + // Test a bad rule definition + c := &cert.NebulaCertificate{} + conf := NewConfig() + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} + _, err := NewFirewallFromConfig(c, conf) + assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") + + // Test both port and code + conf = NewConfig() + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} + _, err = NewFirewallFromConfig(c, conf) + assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") + + // Test missing host, group, cidr, ca_name and ca_sha + conf = NewConfig() + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} + _, err = NewFirewallFromConfig(c, conf) + assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided") + + // Test code/port error + conf = NewConfig() + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} + _, err = NewFirewallFromConfig(c, conf) + assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") + + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} + _, err = NewFirewallFromConfig(c, conf) + assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") + + // Test proto error + conf = NewConfig() + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} + _, err = NewFirewallFromConfig(c, conf) + assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") + + // Test cidr parse error + conf = NewConfig() + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} + _, err = NewFirewallFromConfig(c, conf) + assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") + + // Test both group and groups + conf = NewConfig() + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} + _, err = NewFirewallFromConfig(c, conf) + assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") +} + +func TestAddFirewallRulesFromConfig(t *testing.T) { + // Test adding tcp rule + conf := NewConfig() + mf := &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} + assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf)) + assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + + // Test adding udp rule + conf = NewConfig() + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} + assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf)) + assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + + // Test adding icmp rule + conf = NewConfig() + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} + assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf)) + assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + + // Test adding any rule + conf = NewConfig() + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} + assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + + // Test adding rule with ca_sha + conf = NewConfig() + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} + assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall) + + // Test adding rule with ca_name + conf = NewConfig() + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} + assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall) + + // Test single group + conf = NewConfig() + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} + assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) + + // Test single groups + conf = NewConfig() + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} + assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) + + // Test multiple AND groups + conf = NewConfig() + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} + assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall) + + // Test Add error + conf = NewConfig() + mf = &mockFirewall{} + mf.nextCallReturn = errors.New("test error") + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} + assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`") +} + +func TestTCPRTTTracking(t *testing.T) { + b := make([]byte, 200) + + // Max ip IHL (60 bytes) and tcp IHL (60 bytes) + b[0] = 15 + b[60+12] = 15 << 4 + f := Firewall{ + metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)), + } + + // Set SEQ to 1 + binary.BigEndian.PutUint32(b[60+4:60+8], 1) + + c := &conn{} + setTCPRTTTracking(c, b) + assert.Equal(t, uint32(1), c.Seq) + + // Bad ack - no ack flag + binary.BigEndian.PutUint32(b[60+8:60+12], 80) + assert.False(t, f.checkTCPRTT(c, b)) + + // Bad ack, number is too low + binary.BigEndian.PutUint32(b[60+8:60+12], 0) + b[60+13] = uint8(0x10) + assert.False(t, f.checkTCPRTT(c, b)) + + // Good ack + binary.BigEndian.PutUint32(b[60+8:60+12], 80) + assert.True(t, f.checkTCPRTT(c, b)) + assert.Equal(t, uint32(0), c.Seq) + + // Set SEQ to 1 + binary.BigEndian.PutUint32(b[60+4:60+8], 1) + c = &conn{} + setTCPRTTTracking(c, b) + assert.Equal(t, uint32(1), c.Seq) + + // Good acks + binary.BigEndian.PutUint32(b[60+8:60+12], 81) + assert.True(t, f.checkTCPRTT(c, b)) + assert.Equal(t, uint32(0), c.Seq) + + // Set SEQ to max uint32 - 20 + binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20) + c = &conn{} + setTCPRTTTracking(c, b) + assert.Equal(t, ^uint32(0)-20, c.Seq) + + // Good acks + binary.BigEndian.PutUint32(b[60+8:60+12], 81) + assert.True(t, f.checkTCPRTT(c, b)) + assert.Equal(t, uint32(0), c.Seq) + + // Set SEQ to max uint32 / 2 + binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2) + c = &conn{} + setTCPRTTTracking(c, b) + assert.Equal(t, ^uint32(0)/2, c.Seq) + + // Below + binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1) + assert.False(t, f.checkTCPRTT(c, b)) + assert.Equal(t, ^uint32(0)/2, c.Seq) + + // Halfway below + binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0)) + assert.False(t, f.checkTCPRTT(c, b)) + assert.Equal(t, ^uint32(0)/2, c.Seq) + + // Halfway above is ok + binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)) + assert.True(t, f.checkTCPRTT(c, b)) + assert.Equal(t, uint32(0), c.Seq) + + // Set SEQ to max uint32 + binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)) + c = &conn{} + setTCPRTTTracking(c, b) + assert.Equal(t, ^uint32(0), c.Seq) + + // Halfway + 1 above + binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1) + assert.False(t, f.checkTCPRTT(c, b)) + assert.Equal(t, ^uint32(0), c.Seq) + + // Halfway above + binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2) + assert.True(t, f.checkTCPRTT(c, b)) + assert.Equal(t, uint32(0), c.Seq) +} + +type addRuleCall struct { + incoming bool + proto uint8 + startPort int32 + endPort int32 + groups []string + host string + ip *net.IPNet + caName string + caSha string +} + +type mockFirewall struct { + lastCall addRuleCall + nextCallReturn error +} + +func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { + mf.lastCall = addRuleCall{ + incoming: incoming, + proto: proto, + startPort: startPort, + endPort: endPort, + groups: groups, + host: host, + ip: ip, + caName: caName, + caSha: caSha, + } + + err := mf.nextCallReturn + mf.nextCallReturn = nil + return err +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..5a44be8 --- /dev/null +++ b/go.mod @@ -0,0 +1,32 @@ +module github.com/slackhq/nebula + +go 1.12 + +require ( + github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239 + github.com/armon/go-radix v1.0.0 + github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 + github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect + github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 + github.com/golang/protobuf v1.3.1 + github.com/imdario/mergo v0.3.7 + github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect + github.com/kr/pretty v0.1.0 // indirect + github.com/miekg/dns v1.1.12 + github.com/nbrownus/go-metrics-prometheus v0.0.0-20180622211546-6e6d5173d99c + github.com/prometheus/client_golang v0.9.3 + github.com/prometheus/common v0.4.1 // indirect + github.com/prometheus/procfs v0.0.0-20190523193104-a7aeb8df3389 // indirect + github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a + github.com/sirupsen/logrus v1.4.2 + github.com/songgao/water v0.0.0-20190402020555-6ad6edefb15c + github.com/stretchr/testify v1.3.0 + github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a + github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc // indirect + golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f + golang.org/x/net v0.0.0-20190522155817-f3200d17e092 + golang.org/x/sync v0.0.0-20190423024810-112230192c58 // indirect + golang.org/x/sys v0.0.0-20190524152521-dbbf3f1254d4 + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/yaml.v2 v2.2.2 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..376506f --- /dev/null +++ b/go.sum @@ -0,0 +1,112 @@ +github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239 h1:kFOfPq6dUM1hTo4JG6LR5AXSUEsOjtdm0kw0FtQtMJA= +github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI= +github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= +github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps= +github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 h1:u/UEqS66A5ckRmS4yNpjmVH56sVtS/RfclBAYocb4as= +github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6/go.mod h1:1i71OnUq3iUe1ma7Lr6yG6/rjvM3emb6yoL7xLFzcVQ= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/imdario/mergo v0.3.7 h1:Y+UAYTZ7gDEuOfhxKWy+dvb5dRQ6rJjFSdX2HZY1/gI= +github.com/imdario/mergo v0.3.7/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/miekg/dns v1.1.12 h1:WMhc1ik4LNkTg8U9l3hI1LvxKmIL+f1+WV/SZtCbDDA= +github.com/miekg/dns v1.1.12/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/nbrownus/go-metrics-prometheus v0.0.0-20180622211546-6e6d5173d99c h1:G/mfx/MWYuaaGlHkZQBBXFAJiYnRt/GaOVxnRHjlxg4= +github.com/nbrownus/go-metrics-prometheus v0.0.0-20180622211546-6e6d5173d99c/go.mod h1:1yMri853KAI2pPAUnESjaqZj9JeImOUM+6A4GuuPmTs= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v0.9.3 h1:9iH4JKXLzFbOAdtqv/a+j8aewx2Y8lAjAydhbaScPF8= +github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.4.1 h1:K0MGApIoQvMw27RTdJkPbr3JZ7DNbtxQNyi5STVM6Kw= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.0.0-20190523193104-a7aeb8df3389 h1:F/k2nob1S9M6v5Xkq7KjSTQirOYaYQord0jR4TwyVmY= +github.com/prometheus/procfs v0.0.0-20190523193104-a7aeb8df3389/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= +github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a h1:9ZKAASQSHhDYGoxY8uLVpewe1GDZ2vu2Tr/vTdVAkFQ= +github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/songgao/water v0.0.0-20190402020555-6ad6edefb15c h1:VZsL/fAl8XnHj5Zn+cRvLcFbMHmCj7tdPrkKZSRziJ0= +github.com/songgao/water v0.0.0-20190402020555-6ad6edefb15c/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao= +github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= +github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4= +github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f h1:R423Cnkcp5JABoeemiGEPlt9tHXFfw5kvc0yqlxRPWo= +golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190522155817-f3200d17e092 h1:4QSRKanuywn15aTZvI/mIDEgPQpswuFndXpOj3rKEco= +golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190524152521-dbbf3f1254d4 h1:VSJ45BzqrVgR4clSx415y1rHH7QAGhGt71J0ZmhLYrc= +golang.org/x/sys v0.0.0-20190524152521-dbbf3f1254d4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/handshake.go b/handshake.go new file mode 100644 index 0000000..9836f83 --- /dev/null +++ b/handshake.go @@ -0,0 +1,82 @@ +package nebula + +import ( + "crypto/hmac" + "crypto/sha256" + "errors" + "github.com/golang/protobuf/proto" +) + +const ( + handshakeIXPSK0 = 0 + handshakeXXPSK0 = 1 +) + +func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) { + newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex) + //TODO: For stage 1 we won't have hostinfo yet but stage 2 and above would require it, this check may be helpful in those cases + //if err != nil { + // l.WithError(err).WithField("udpAddr", addr).Error("Error while finding host info for handshake message") + // return + //} + + tearDown := false + switch h.Subtype { + case handshakeIXPSK0: + switch h.MessageCounter { + case 1: + tearDown = ixHandshakeStage1(f, addr, newHostinfo, packet, h) + case 2: + tearDown = ixHandshakeStage2(f, addr, newHostinfo, packet, h) + } + } + + if tearDown && newHostinfo != nil { + f.handshakeManager.DeleteIndex(newHostinfo.localIndexId) + f.handshakeManager.DeleteVpnIP(newHostinfo.hostId) + } +} + +func HandshakeBytesWithMAC(details *NebulaHandshakeDetails, key []byte) ([]byte, error) { + mac := hmac.New(sha256.New, key) + + b, err := proto.Marshal(details) + if err != nil { + return nil, errors.New("Unable to marshal nebula handshake") + } + mac.Write(b) + sum := mac.Sum(nil) + + hs := &NebulaHandshake{ + Details: details, + Hmac: sum, + } + + hsBytes, err := proto.Marshal(hs) + if err != nil { + l.Debugln("failed to generate NebulaHandshake protobuf", err) + } + + return hsBytes, nil +} + +func (hs *NebulaHandshake) CheckHandshakeMAC(keys [][]byte) bool { + + b, err := proto.Marshal(hs.Details) + if err != nil { + return false + } + + for _, k := range keys { + mac := hmac.New(sha256.New, k) + mac.Write(b) + expectedMAC := mac.Sum(nil) + if hmac.Equal(hs.Hmac, expectedMAC) { + return true + } + } + + //l.Debugln(hs.Hmac, expectedMAC) + + return false +} diff --git a/handshake_ix.go b/handshake_ix.go new file mode 100644 index 0000000..2eaabc3 --- /dev/null +++ b/handshake_ix.go @@ -0,0 +1,356 @@ +package nebula + +import ( + "sync/atomic" + "time" + + "bytes" + + "github.com/flynn/noise" + "github.com/golang/protobuf/proto" +) + +// NOISE IX Handshakes + +// This function constructs a handshake packet, but does not actually send it +// Sending is done by the handshake manager +func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) { + // This queries the lighthouse if we don't know a remote for the host + if hostinfo.remote == nil { + ips, err := f.lightHouse.Query(vpnIp, f) + if err != nil { + //l.Debugln(err) + } + for _, ip := range ips { + hostinfo.AddRemote(ip) + } + } + + myIndex, err := generateIndex() + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") + return + } + ci := hostinfo.ConnectionState + f.handshakeManager.AddIndexHostInfo(myIndex, hostinfo) + + hsProto := &NebulaHandshakeDetails{ + InitiatorIndex: myIndex, + Time: uint64(time.Now().Unix()), + Cert: ci.certState.rawCertificateNoKey, + } + + hs := &NebulaHandshake{ + Details: hsProto, + Hmac: nil, + } + + hsBytes, err := proto.Marshal(hs) + //hsBytes, err := HandshakeBytesWithMAC(hsProto, f.handshakeMACKey) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") + return + } + + header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1) + atomic.AddUint64(ci.messageCounter, 1) + + msg, _, _, err := ci.H.WriteMessage(header, hsBytes) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") + return + } + + hostinfo.HandshakePacket[0] = msg + hostinfo.HandshakeReady = true + hostinfo.handshakeStart = time.Now() + /* + l.Debugln("ZZZZZZZZZZREMOTE: ", hostinfo.remote) + */ + +} + +func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { + var ip uint32 + if h.RemoteIndex == 0 { + ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0) + // Mark packet 1 as seen so it doesn't show up as missed + ci.window.Update(1) + + msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) + if err != nil { + l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") + return true + } + + hs := &NebulaHandshake{} + err = proto.Unmarshal(msg, hs) + /* + l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) + */ + if err != nil || hs.Details == nil { + l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + return true + } + + hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex) + if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) { + if msg, ok := hostinfo.HandshakePacket[2]; ok { + err := f.outside.WriteTo(msg, addr) + if err != nil { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). + WithError(err).Error("Failed to send handshake message") + } else { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). + Info("Handshake message sent") + } + return false + } + + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true). + WithField("packets", hostinfo.HandshakePacket). + Error("Seen this handshake packet already but don't have a cached packet to return") + } + + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) + if err != nil { + l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). + Info("Invalid certificate from host") + return true + } + vpnIP := ip2int(remoteCert.Details.Ips[0].IP) + + myIndex, err := generateIndex() + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") + return true + } + + hostinfo, err = f.handshakeManager.AddIndex(myIndex, ci) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager") + + return true + } + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("Handshake message received") + + hostinfo.remoteIndexId = hs.Details.InitiatorIndex + hs.Details.ResponderIndex = myIndex + hs.Details.Cert = ci.certState.rawCertificateNoKey + hs.Hmac = nil + + hsBytes, err := proto.Marshal(hs) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") + return true + } + + header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2) + msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") + return true + } + + if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) { + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("Prevented a handshake race") + + // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues + f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + return true + } + + hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:])) + copy(hostinfo.HandshakePacket[0], packet[HeaderLen:]) + + // Regardless of whether you are the sender or receiver, you should arrive here + // and complete standing up the connection. + if dKey != nil && eKey != nil { + hostinfo.HandshakePacket[2] = make([]byte, len(msg)) + copy(hostinfo.HandshakePacket[2], msg) + + err := f.outside.WriteTo(msg, addr) + if err != nil { + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + WithError(err).Error("Failed to send handshake") + } else { + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("Handshake message sent") + } + + ip = ip2int(remoteCert.Details.Ips[0].IP) + ci.peerCert = remoteCert + ci.dKey = NewNebulaCipherState(dKey) + ci.eKey = NewNebulaCipherState(eKey) + //l.Debugln("got symmetric pairs") + + //hostinfo.ClearRemotes() + hostinfo.AddRemote(*addr) + f.lightHouse.AddRemoteAndReset(ip, addr) + if f.serveDns { + dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) + } + + ho, err := f.hostMap.QueryVpnIP(vpnIP) + if err == nil && ho.localIndexId != 0 { + l.WithField("vpnIp", vpnIP). + WithField("action", "removing stale index"). + WithField("index", ho.localIndexId). + Debug("Handshake processing") + f.hostMap.DeleteIndex(ho.localIndexId) + } + + f.hostMap.AddIndexHostInfo(hostinfo.localIndexId, hostinfo) + f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo) + + hostinfo.handshakeComplete() + } else { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Noise did not arrive at a key") + return true + } + + } + + f.hostMap.AddRemote(ip, addr) + /* + l.Debugln("111 ZZZZZZZZZZADDR: ", addr) + l.Debugln("111 ZZZZZZZZZZREMOTE: ", hostinfo.remote) + l.Debugln("111 ZZZZZZZZZZREMOTEs: ", hostinfo.Remotes[0].addr) + */ + return false +} + +func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool { + if hostinfo == nil { + return true + } + + if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). + Error("Already seen this handshake packet") + return false + } + + ci := hostinfo.ConnectionState + // Mark packet 2 as seen so it doesn't show up as missed + ci.window.Update(2) + + hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:])) + copy(hostinfo.HandshakePacket[2], packet[HeaderLen:]) + + msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). + Error("Failed to call noise.ReadMessage") + + // We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying + // to DOS us. Every other error condition after should to allow a possible good handshake to complete in the + // near future + return false + } + + hs := &NebulaHandshake{} + err = proto.Unmarshal(msg, hs) + if err != nil || hs.Details == nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + return true + } + + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Error("Invalid certificate from host") + return true + } + vpnIP := ip2int(remoteCert.Details.Ips[0].IP) + + duration := time.Since(hostinfo.handshakeStart).Nanoseconds() + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + WithField("durationNs", duration). + Info("Handshake message received") + + //ci.remoteIndex = hs.ResponderIndex + hostinfo.remoteIndexId = hs.Details.ResponderIndex + hs.Details.Cert = ci.certState.rawCertificateNoKey + + /* + hsBytes, err := proto.Marshal(hs) + if err != nil { + l.Debugln("Failed to marshal handshake: ", err) + return + } + */ + + // Regardless of whether you are the sender or receiver, you should arrive here + // and complete standing up the connection. + if dKey != nil && eKey != nil { + ip := ip2int(remoteCert.Details.Ips[0].IP) + ci.peerCert = remoteCert + ci.dKey = NewNebulaCipherState(dKey) + ci.eKey = NewNebulaCipherState(eKey) + //l.Debugln("got symmetric pairs") + + //hostinfo.ClearRemotes() + f.hostMap.AddRemote(ip, addr) + f.lightHouse.AddRemoteAndReset(ip, addr) + if f.serveDns { + dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) + } + + ho, err := f.hostMap.QueryVpnIP(vpnIP) + if err == nil && ho.localIndexId != 0 { + l.WithField("vpnIp", vpnIP). + WithField("action", "removing stale index"). + WithField("index", ho.localIndexId). + Debug("Handshake processing") + f.hostMap.DeleteIndex(ho.localIndexId) + } + + f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo) + f.hostMap.AddIndexHostInfo(hostinfo.localIndexId, hostinfo) + + hostinfo.handshakeComplete() + f.metricHandshakes.Update(duration) + } else { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Error("Noise did not arrive at a key") + return true + } + + return false + /* + l.Debugln("222 ZZZZZZZZZZREMOTE: ", hostinfo.remote) + */ +} diff --git a/handshake_manager.go b/handshake_manager.go new file mode 100644 index 0000000..ec007a0 --- /dev/null +++ b/handshake_manager.go @@ -0,0 +1,200 @@ +package nebula + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "net" + "time" + + "github.com/sirupsen/logrus" +) + +const ( + // Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries + // With 100ms interval and 20 retries is 23.5 seconds + HandshakeTryInterval = time.Millisecond * 100 + HandshakeRetries = 20 + // HandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses + HandshakeWaitRotation = 5 +) + +type HandshakeManager struct { + pendingHostMap *HostMap + mainHostMap *HostMap + lightHouse *LightHouse + outside *udpConn + + OutboundHandshakeTimer *SystemTimerWheel + InboundHandshakeTimer *SystemTimerWheel +} + +func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn) *HandshakeManager { + return &HandshakeManager{ + pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges), + mainHostMap: mainHostMap, + lightHouse: lightHouse, + outside: outside, + + OutboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries), + InboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries), + } +} + +func (c *HandshakeManager) Run(f EncWriter) { + clockSource := time.Tick(HandshakeTryInterval) + for now := range clockSource { + c.NextOutboundHandshakeTimerTick(now, f) + c.NextInboundHandshakeTimerTick(now) + } +} + +func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) { + c.OutboundHandshakeTimer.advance(now) + for { + ep := c.OutboundHandshakeTimer.Purge() + if ep == nil { + break + } + vpnIP := ep.(uint32) + + index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP) + if err != nil { + continue + } + + hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP) + if err != nil { + continue + } + + // If we haven't finished the handshake and we haven't hit max retries, query + // lighthouse and then send the handshake packet again. + if hostinfo.HandshakeCounter < HandshakeRetries && !hostinfo.HandshakeComplete { + if hostinfo.remote == nil { + // We continue to query the lighthouse because hosts may + // come online during handshake retries. If the query + // succeeds (no error), add the lighthouse info to hostinfo + ips, err := c.lightHouse.Query(vpnIP, f) + if err == nil { + for _, ip := range ips { + hostinfo.AddRemote(ip) + } + hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges) + } + } + + hostinfo.HandshakeCounter++ + + // We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through + // all the others until we can stand up a connection. + if hostinfo.HandshakeCounter > HandshakeWaitRotation { + hostinfo.rotateRemote() + } + + // Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation + if hostinfo.HandshakeReady && hostinfo.remote != nil { + err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote) + if err != nil { + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote). + WithField("initiatorIndex", hostinfo.localIndexId). + WithField("remoteIndex", hostinfo.remoteIndexId). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithError(err).Error("Failed to send handshake message") + } else { + //TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should + // keep the real packet struct around for logging purposes + l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote). + WithField("initiatorIndex", hostinfo.localIndexId). + WithField("remoteIndex", hostinfo.remoteIndexId). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("Handshake message sent") + } + } + + // Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try + //l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter)) + c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter)) + } else { + c.pendingHostMap.DeleteVpnIP(vpnIP) + c.pendingHostMap.DeleteIndex(index) + } + } +} + +func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) { + c.InboundHandshakeTimer.advance(now) + for { + ep := c.InboundHandshakeTimer.Purge() + if ep == nil { + break + } + index := ep.(uint32) + + vpnIP, err := c.pendingHostMap.GetVpnIPByIndex(index) + if err != nil { + continue + } + c.pendingHostMap.DeleteIndex(index) + c.pendingHostMap.DeleteVpnIP(vpnIP) + } +} + +func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo { + hostinfo := c.pendingHostMap.AddVpnIP(vpnIP) + // We lock here and use an array to insert items to prevent locking the + // main receive thread for very long by waiting to add items to the pending map + c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval) + return hostinfo +} + +func (c *HandshakeManager) DeleteVpnIP(vpnIP uint32) { + //l.Debugln("Deleting pending vpn ip :", IntIp(vpnIP)) + c.pendingHostMap.DeleteVpnIP(vpnIP) +} + +func (c *HandshakeManager) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) { + hostinfo, err := c.pendingHostMap.AddIndex(index, ci) + if err != nil { + return nil, fmt.Errorf("Issue adding index: %d", index) + } + //c.mainHostMap.AddIndexHostInfo(index, hostinfo) + c.InboundHandshakeTimer.Add(index, time.Second*10) + return hostinfo, nil +} + +func (c *HandshakeManager) AddIndexHostInfo(index uint32, h *HostInfo) { + c.pendingHostMap.AddIndexHostInfo(index, h) +} + +func (c *HandshakeManager) DeleteIndex(index uint32) { + //l.Debugln("Deleting pending index :", index) + c.pendingHostMap.DeleteIndex(index) +} + +func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) { + return c.pendingHostMap.QueryIndex(index) +} + +func (c *HandshakeManager) EmitStats() { + c.pendingHostMap.EmitStats("pending") + c.mainHostMap.EmitStats("main") +} + +// Utility functions below + +func generateIndex() (uint32, error) { + b := make([]byte, 4) + _, err := rand.Read(b) + if err != nil { + l.Errorln(err) + return 0, err + } + + index := binary.BigEndian.Uint32(b) + if l.Level >= logrus.DebugLevel { + l.WithField("index", index). + Debug("Generated index") + } + return index, nil +} diff --git a/handshake_manager_test.go b/handshake_manager_test.go new file mode 100644 index 0000000..6822b7c --- /dev/null +++ b/handshake_manager_test.go @@ -0,0 +1,191 @@ +package nebula + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var indexes []uint32 = []uint32{1000, 2000, 3000, 4000} + +//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923} +var ips []uint32 = []uint32{9000} + +func Test_NewHandshakeManagerIndex(t *testing.T) { + _, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") + _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") + _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + preferredRanges := []*net.IPNet{localrange} + mainHM := NewHostMap("test", vpncidr, preferredRanges) + + blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}) + + now := time.Now() + blah.NextInboundHandshakeTimerTick(now) + + // Add four indexes + for _, v := range indexes { + blah.AddIndex(v, &ConnectionState{}) + } + // Confirm they are in the pending index list + for _, v := range indexes { + assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v)) + } + // Adding something to pending should not affect the main hostmap + assert.Len(t, mainHM.Indexes, 0) + // Jump ahead 8 seconds + for i := 1; i <= HandshakeRetries; i++ { + next_tick := now.Add(HandshakeTryInterval * time.Duration(i)) + blah.NextInboundHandshakeTimerTick(next_tick) + } + // Confirm they are still in the pending index list + for _, v := range indexes { + assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v)) + } + // Jump ahead 4 more seconds + next_tick := now.Add(12 * time.Second) + blah.NextInboundHandshakeTimerTick(next_tick) + // Confirm they have been removed + for _, v := range indexes { + assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(v)) + } +} + +func Test_NewHandshakeManagerVpnIP(t *testing.T) { + _, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") + _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") + _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + preferredRanges := []*net.IPNet{localrange} + mw := &mockEncWriter{} + mainHM := NewHostMap("test", vpncidr, preferredRanges) + + blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}) + + now := time.Now() + blah.NextOutboundHandshakeTimerTick(now, mw) + + // Add four "IPs" - which are just uint32s + for _, v := range ips { + blah.AddVpnIP(v) + } + // Adding something to pending should not affect the main hostmap + assert.Len(t, mainHM.Hosts, 0) + // Confirm they are in the pending index list + for _, v := range ips { + assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v)) + } + + // Jump ahead `HandshakeRetries` ticks + cumulative := time.Duration(0) + for i := 0; i <= HandshakeRetries+1; i++ { + cumulative += time.Duration(i)*HandshakeTryInterval + 1 + next_tick := now.Add(cumulative) + //l.Infoln(next_tick) + blah.NextOutboundHandshakeTimerTick(next_tick, mw) + } + + // Confirm they are still in the pending index list + for _, v := range ips { + assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v)) + } + // Jump ahead 1 more second + cumulative += time.Duration(HandshakeRetries+1) * HandshakeTryInterval + next_tick := now.Add(cumulative) + //l.Infoln(next_tick) + blah.NextOutboundHandshakeTimerTick(next_tick, mw) + // Confirm they have been removed + for _, v := range ips { + assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(v)) + } +} + +func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { + _, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") + _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") + _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + preferredRanges := []*net.IPNet{localrange} + mw := &mockEncWriter{} + mainHM := NewHostMap("test", vpncidr, preferredRanges) + + blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}) + + now := time.Now() + blah.NextOutboundHandshakeTimerTick(now, mw) + + hostinfo := blah.AddVpnIP(101010) + // Pretned we have an index too + blah.AddIndexHostInfo(12341234, hostinfo) + assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234)) + + // Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending + // but not main hostmap + cumulative := time.Duration(0) + for i := 1; i <= HandshakeRetries+2; i++ { + cumulative += HandshakeTryInterval * time.Duration(i) + next_tick := now.Add(cumulative) + blah.NextOutboundHandshakeTimerTick(next_tick, mw) + } + /* + for i := 0; i <= HandshakeRetries+1; i++ { + next_tick := now.Add(cumulative) + //l.Infoln(next_tick) + blah.NextOutboundHandshakeTimerTick(next_tick) + } + */ + /* + for i := 0; i <= HandshakeRetries+1; i++ { + next_tick := now.Add(time.Duration(i) * time.Second) + blah.NextOutboundHandshakeTimerTick(next_tick) + } + */ + + /* + cumulative += HandshakeTryInterval*time.Duration(HandshakeRetries) + 3 + next_tick := now.Add(cumulative) + l.Infoln(cumulative, next_tick) + blah.NextOutboundHandshakeTimerTick(next_tick) + */ + assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010)) + assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234)) +} + +func Test_NewHandshakeManagerIndexcleanup(t *testing.T) { + _, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") + _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") + _, localrange, _ := net.ParseCIDR("10.1.1.1/24") + preferredRanges := []*net.IPNet{localrange} + mainHM := NewHostMap("test", vpncidr, preferredRanges) + + blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}) + + now := time.Now() + blah.NextInboundHandshakeTimerTick(now) + + hostinfo, _ := blah.AddIndex(12341234, &ConnectionState{}) + // Pretned we have an index too + blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo) + assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010)) + + for i := 1; i <= HandshakeRetries+2; i++ { + next_tick := now.Add(HandshakeTryInterval * time.Duration(i)) + blah.NextInboundHandshakeTimerTick(next_tick) + } + + next_tick := now.Add(HandshakeTryInterval*HandshakeRetries + 3) + blah.NextInboundHandshakeTimerTick(next_tick) + assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010)) + assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234)) +} + +type mockEncWriter struct { +} + +func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { + return +} + +func (mw *mockEncWriter) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { + return +} diff --git a/header.go b/header.go new file mode 100644 index 0000000..3f15faa --- /dev/null +++ b/header.go @@ -0,0 +1,188 @@ +package nebula + +import ( + "encoding/binary" + "encoding/json" + "errors" + "fmt" +) + +//Version 1 header: +// 0 31 +// |-----------------------------------------------------------------------| +// | Version (uint4) | Type (uint4) | Subtype (uint8) | Reserved (uint16) | 32 +// |-----------------------------------------------------------------------| +// | Remote index (uint32) | 64 +// |-----------------------------------------------------------------------| +// | Message counter | 96 +// | (uint64) | 128 +// |-----------------------------------------------------------------------| +// | payload... | + +const ( + Version uint8 = 1 + HeaderLen = 16 +) + +type NebulaMessageType uint8 +type NebulaMessageSubType uint8 + +const ( + handshake NebulaMessageType = 0 + message NebulaMessageType = 1 + recvError NebulaMessageType = 2 + lightHouse NebulaMessageType = 3 + test NebulaMessageType = 4 + closeTunnel NebulaMessageType = 5 + + //TODO These are deprecated as of 06/12/2018 - NB + testRemote NebulaMessageType = 6 + testRemoteReply NebulaMessageType = 7 +) + +var typeMap = map[NebulaMessageType]string{ + handshake: "handshake", + message: "message", + recvError: "recvError", + lightHouse: "lightHouse", + test: "test", + closeTunnel: "closeTunnel", + + //TODO These are deprecated as of 06/12/2018 - NB + testRemote: "testRemote", + testRemoteReply: "testRemoteReply", +} + +const ( + testRequest NebulaMessageSubType = 0 + testReply NebulaMessageSubType = 1 +) + +var eHeaderTooShort = errors.New("header is too short") + +var subTypeTestMap = map[NebulaMessageSubType]string{ + testRequest: "testRequest", + testReply: "testReply", +} + +var subTypeNoneMap = map[NebulaMessageSubType]string{0: "none"} + +var subTypeMap = map[NebulaMessageType]*map[NebulaMessageSubType]string{ + message: &subTypeNoneMap, + recvError: &subTypeNoneMap, + lightHouse: &subTypeNoneMap, + test: &subTypeTestMap, + closeTunnel: &subTypeNoneMap, + handshake: { + handshakeIXPSK0: "ix_psk0", + }, + //TODO: these are deprecated + testRemote: &subTypeNoneMap, + testRemoteReply: &subTypeNoneMap, +} + +type Header struct { + Version uint8 + Type NebulaMessageType + Subtype NebulaMessageSubType + Reserved uint16 + RemoteIndex uint32 + MessageCounter uint64 +} + +// HeaderEncode uses the provided byte array to encode the provided header values into. +// Byte array must be capped higher than HeaderLen or this will panic +func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []byte { + b = b[:HeaderLen] + b[0] = byte(v<<4 | (t & 0x0f)) + b[1] = byte(st) + binary.BigEndian.PutUint16(b[2:4], 0) + binary.BigEndian.PutUint32(b[4:8], ri) + binary.BigEndian.PutUint64(b[8:16], c) + return b +} + +// String creates a readable string representation of a header +func (h *Header) String() string { + if h == nil { + return "" + } + return fmt.Sprintf("ver=%d type=%s subtype=%s reserved=%#x remoteindex=%v messagecounter=%v", + h.Version, h.TypeName(), h.SubTypeName(), h.Reserved, h.RemoteIndex, h.MessageCounter) +} + +// MarshalJSON creates a json string representation of a header +func (h *Header) MarshalJSON() ([]byte, error) { + return json.Marshal(m{ + "version": h.Version, + "type": h.TypeName(), + "subType": h.SubTypeName(), + "reserved": h.Reserved, + "remoteIndex": h.RemoteIndex, + "messageCounter": h.MessageCounter, + }) +} + +// Encode turns header into bytes +func (h *Header) Encode(b []byte) ([]byte, error) { + if h == nil { + return nil, errors.New("nil header") + } + + return HeaderEncode(b, h.Version, uint8(h.Type), uint8(h.Subtype), h.RemoteIndex, h.MessageCounter), nil +} + +// Parse is a helper function to parses given bytes into new Header struct +func (h *Header) Parse(b []byte) error { + if len(b) < HeaderLen { + return eHeaderTooShort + } + // get upper 4 bytes + h.Version = uint8((b[0] >> 4) & 0x0f) + // get lower 4 bytes + h.Type = NebulaMessageType(b[0] & 0x0f) + h.Subtype = NebulaMessageSubType(b[1]) + h.Reserved = binary.BigEndian.Uint16(b[2:4]) + h.RemoteIndex = binary.BigEndian.Uint32(b[4:8]) + h.MessageCounter = binary.BigEndian.Uint64(b[8:16]) + return nil +} + +// TypeName will transform the headers message type into a human string +func (h *Header) TypeName() string { + return TypeName(h.Type) +} + +// TypeName will transform a nebula message type into a human string +func TypeName(t NebulaMessageType) string { + if n, ok := typeMap[t]; ok { + return n + } + + return "unknown" +} + +// SubTypeName will transform the headers message sub type into a human string +func (h *Header) SubTypeName() string { + return SubTypeName(h.Type, h.Subtype) +} + +// SubTypeName will transform a nebula message sub type into a human string +func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string { + if n, ok := subTypeMap[t]; ok { + if x, ok := (*n)[s]; ok { + return x + } + } + + return "unknown" +} + +// NewHeader turns bytes into a header +func NewHeader(b []byte) (*Header, error) { + h := new(Header) + if err := h.Parse(b); err != nil { + return nil, err + } + return h, nil +} diff --git a/header_test.go b/header_test.go new file mode 100644 index 0000000..02f71e0 --- /dev/null +++ b/header_test.go @@ -0,0 +1,118 @@ +package nebula + +import ( + "github.com/stretchr/testify/assert" + "reflect" + "testing" +) + +type headerTest struct { + expectedBytes []byte + *Header +} + +// 0001 0010 00010010 +var headerBigEndianTests = []headerTest{{ + expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9}, + // 1010 0000 + Header: &Header{ + // 1111 1+2+4+8 = 15 + Version: 5, + Type: 4, + Subtype: 0, + Reserved: 0, + RemoteIndex: 10, + MessageCounter: 9, + }, +}, +} + +func TestEncode(t *testing.T) { + for _, tt := range headerBigEndianTests { + b, err := tt.Encode(make([]byte, HeaderLen)) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, tt.expectedBytes, b) + } +} + +func TestParse(t *testing.T) { + for _, tt := range headerBigEndianTests { + b := tt.expectedBytes + parsedHeader := &Header{} + parsedHeader.Parse(b) + + if !reflect.DeepEqual(tt.Header, parsedHeader) { + t.Fatalf("got %#v; want %#v", parsedHeader, tt.Header) + } + } +} + +func TestTypeName(t *testing.T) { + assert.Equal(t, "test", TypeName(test)) + assert.Equal(t, "test", (&Header{Type: test}).TypeName()) + + assert.Equal(t, "unknown", TypeName(99)) + assert.Equal(t, "unknown", (&Header{Type: 99}).TypeName()) +} + +func TestSubTypeName(t *testing.T) { + assert.Equal(t, "testRequest", SubTypeName(test, testRequest)) + assert.Equal(t, "testRequest", (&Header{Type: test, Subtype: testRequest}).SubTypeName()) + + assert.Equal(t, "unknown", SubTypeName(99, testRequest)) + assert.Equal(t, "unknown", (&Header{Type: 99, Subtype: testRequest}).SubTypeName()) + + assert.Equal(t, "unknown", SubTypeName(test, 99)) + assert.Equal(t, "unknown", (&Header{Type: test, Subtype: 99}).SubTypeName()) + + assert.Equal(t, "none", SubTypeName(message, 0)) + assert.Equal(t, "none", (&Header{Type: message, Subtype: 0}).SubTypeName()) +} + +func TestTypeMap(t *testing.T) { + // Force people to document this stuff + assert.Equal(t, map[NebulaMessageType]string{ + handshake: "handshake", + message: "message", + recvError: "recvError", + lightHouse: "lightHouse", + test: "test", + closeTunnel: "closeTunnel", + testRemote: "testRemote", + testRemoteReply: "testRemoteReply", + }, typeMap) + + assert.Equal(t, map[NebulaMessageType]*map[NebulaMessageSubType]string{ + message: &subTypeNoneMap, + recvError: &subTypeNoneMap, + lightHouse: &subTypeNoneMap, + test: &subTypeTestMap, + closeTunnel: &subTypeNoneMap, + handshake: { + handshakeIXPSK0: "ix_psk0", + }, + testRemote: &subTypeNoneMap, + testRemoteReply: &subTypeNoneMap, + }, subTypeMap) +} + +func TestHeader_String(t *testing.T) { + assert.Equal( + t, + "ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97", + (&Header{100, test, testRequest, 99, 98, 97}).String(), + ) +} + +func TestHeader_MarshalJSON(t *testing.T) { + b, err := (&Header{100, test, testRequest, 99, 98, 97}).MarshalJSON() + assert.Nil(t, err) + assert.Equal( + t, + "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", + string(b), + ) +} diff --git a/hostmap.go b/hostmap.go new file mode 100644 index 0000000..20f8ce5 --- /dev/null +++ b/hostmap.go @@ -0,0 +1,743 @@ +package nebula + +import ( + "encoding/json" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/rcrowley/go-metrics" + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" +) + +//const ProbeLen = 100 +const PromoteEvery = 1000 +const MaxRemotes = 10 + +// How long we should prevent roaming back to the previous IP. +// This helps prevent flapping due to packets already in flight +const RoamingSupressSeconds = 2 + +type HostMap struct { + sync.RWMutex //Because we concurrently read and write to our maps + name string + Indexes map[uint32]*HostInfo + Hosts map[uint32]*HostInfo + preferredRanges []*net.IPNet + vpnCIDR *net.IPNet + defaultRoute uint32 +} + +type HostInfo struct { + remote *udpAddr + Remotes []*HostInfoDest + promoteCounter uint32 + ConnectionState *ConnectionState + handshakeStart time.Time + HandshakeReady bool + HandshakeCounter int + HandshakeComplete bool + HandshakePacket map[uint8][]byte + packetStore []*cachedPacket + remoteIndexId uint32 + localIndexId uint32 + hostId uint32 + recvError int + + lastRoam time.Time + lastRoamRemote *udpAddr +} + +type cachedPacket struct { + messageType NebulaMessageType + messageSubType NebulaMessageSubType + callback packetCallback + packet []byte +} + +type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte) + +type HostInfoDest struct { + active bool + addr *udpAddr + //probes [ProbeLen]bool + probeCounter int +} + +type Probe struct { + Addr *net.UDPAddr + Counter int +} + +func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { + h := map[uint32]*HostInfo{} + i := map[uint32]*HostInfo{} + m := HostMap{ + name: name, + Indexes: i, + Hosts: h, + preferredRanges: preferredRanges, + vpnCIDR: vpnCIDR, + defaultRoute: 0, + } + return &m +} + +// UpdateStats takes a name and reports host and index counts to the stats collection system +func (hm *HostMap) EmitStats(name string) { + hm.RLock() + hostLen := len(hm.Hosts) + indexLen := len(hm.Indexes) + hm.RUnlock() + + metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen)) + metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen)) +} + +func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) { + hm.RLock() + if i, ok := hm.Hosts[vpnIP]; ok { + index := i.localIndexId + hm.RUnlock() + return index, nil + } + hm.RUnlock() + return 0, errors.New("vpn IP not found") +} + +func (hm *HostMap) GetVpnIPByIndex(index uint32) (uint32, error) { + hm.RLock() + if i, ok := hm.Indexes[index]; ok { + vpnIP := i.hostId + hm.RUnlock() + return vpnIP, nil + } + hm.RUnlock() + return 0, errors.New("vpn IP not found") +} + +func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) { + hm.Lock() + hm.Hosts[ip] = hostinfo + hm.Unlock() +} + +func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo { + h := &HostInfo{} + hm.RLock() + if _, ok := hm.Hosts[vpnIP]; !ok { + hm.RUnlock() + h = &HostInfo{ + Remotes: []*HostInfoDest{}, + promoteCounter: 0, + hostId: vpnIP, + HandshakePacket: make(map[uint8][]byte, 0), + } + hm.Lock() + hm.Hosts[vpnIP] = h + hm.Unlock() + return h + } else { + h = hm.Hosts[vpnIP] + hm.RUnlock() + return h + } +} + +func (hm *HostMap) DeleteVpnIP(vpnIP uint32) { + hm.Lock() + delete(hm.Hosts, vpnIP) + if len(hm.Hosts) == 0 { + hm.Hosts = map[uint32]*HostInfo{} + } + hm.Unlock() + + if l.Level >= logrus.DebugLevel { + l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}). + Debug("Hostmap vpnIp deleted") + } +} + +func (hm *HostMap) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) { + hm.Lock() + if _, ok := hm.Indexes[index]; !ok { + h := &HostInfo{ + ConnectionState: ci, + Remotes: []*HostInfoDest{}, + localIndexId: index, + HandshakePacket: make(map[uint8][]byte, 0), + } + hm.Indexes[index] = h + l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), + "hostinfo": m{"existing": false, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). + Debug("Hostmap index added") + + hm.Unlock() + return h, nil + } + hm.Unlock() + return nil, fmt.Errorf("refusing to overwrite existing index: %d", index) +} + +func (hm *HostMap) AddIndexHostInfo(index uint32, h *HostInfo) { + hm.Lock() + h.localIndexId = index + hm.Indexes[index] = h + hm.Unlock() + + if l.Level > logrus.DebugLevel { + l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), + "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). + Debug("Hostmap index added") + } +} + +func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) { + hm.Lock() + h.hostId = vpnIP + hm.Hosts[vpnIP] = h + hm.Unlock() + + if l.Level > logrus.DebugLevel { + l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). + Debug("Hostmap vpnIp added") + } +} + +func (hm *HostMap) DeleteIndex(index uint32) { + hm.Lock() + delete(hm.Indexes, index) + if len(hm.Indexes) == 0 { + hm.Indexes = map[uint32]*HostInfo{} + } + hm.Unlock() + + if l.Level >= logrus.DebugLevel { + l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}). + Debug("Hostmap index deleted") + } +} + +func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) { + //TODO: we probably just want ot return bool instead of error, or at least a static error + hm.RLock() + if h, ok := hm.Indexes[index]; ok { + hm.RUnlock() + return h, nil + } else { + hm.RUnlock() + return nil, errors.New("unable to find index") + } +} + +// This function needs to range because we don't keep a map of remote indexes. +func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) { + hm.RLock() + for _, h := range hm.Indexes { + if h.ConnectionState != nil && h.remoteIndexId == index { + hm.RUnlock() + return h, nil + } + } + for _, h := range hm.Hosts { + if h.ConnectionState != nil && h.remoteIndexId == index { + hm.RUnlock() + return h, nil + } + } + hm.RUnlock() + return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name) +} + +func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo { + hm.Lock() + i, v := hm.Hosts[vpnIp] + if v { + i.AddRemote(*remote) + } else { + i = &HostInfo{ + Remotes: []*HostInfoDest{NewHostInfoDest(remote)}, + promoteCounter: 0, + hostId: vpnIp, + HandshakePacket: make(map[uint8][]byte, 0), + } + i.remote = i.Remotes[0].addr + hm.Hosts[vpnIp] = i + l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}). + Debug("Hostmap remote ip added") + } + i.ForcePromoteBest(hm.preferredRanges) + hm.Unlock() + return i +} + +func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) { + return hm.queryVpnIP(vpnIp, nil) +} + +// PromoteBestQueryVpnIP will attempt to lazily switch to the best remote every +// `PromoteEvery` calls to this function for a given host. +func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostInfo, error) { + return hm.queryVpnIP(vpnIp, ifce) +} + +func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) { + if hm.vpnCIDR.Contains(int2ip(vpnIp)) == false && hm.defaultRoute != 0 { + // FIXME: this shouldn't ship + d := hm.Hosts[hm.defaultRoute] + if d != nil { + return hm.Hosts[hm.defaultRoute], nil + } + } + hm.RLock() + if h, ok := hm.Hosts[vpnIp]; ok { + if promoteIfce != nil { + h.TryPromoteBest(hm.preferredRanges, promoteIfce) + } + //fmt.Println(h.remote) + hm.RUnlock() + return h, nil + } else { + //return &net.UDPAddr{}, nil, errors.New("Unable to find host") + hm.RUnlock() + /* + if lightHouse != nil { + lightHouse.Query(vpnIp) + return nil, errors.New("Unable to find host") + } + */ + return nil, errors.New("unable to find host") + } +} + +func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool { + hm.RLock() + if i, ok := hm.Hosts[vpnIP]; ok { + if i == nil { + hm.RUnlock() + return false + } + complete := i.HandshakeComplete + hm.RUnlock() + return complete + + } + hm.RUnlock() + return false +} + +func (hm *HostMap) CheckHandshakeCompleteIndex(index uint32) bool { + hm.RLock() + if i, ok := hm.Indexes[index]; ok { + if i == nil { + hm.RUnlock() + return false + } + complete := i.HandshakeComplete + hm.RUnlock() + return complete + + } + hm.RUnlock() + return false +} + +func (hm *HostMap) ClearRemotes(vpnIP uint32) { + hm.Lock() + i := hm.Hosts[vpnIP] + if i == nil { + hm.Unlock() + return + } + i.remote = nil + i.Remotes = nil + hm.Unlock() +} + +func (hm *HostMap) SetDefaultRoute(ip uint32) { + hm.defaultRoute = ip +} + +func (hm *HostMap) PunchList() []*udpAddr { + var list []*udpAddr + hm.RLock() + for _, v := range hm.Hosts { + for _, r := range v.Remotes { + list = append(list, r.addr) + } + // if h, ok := hm.Hosts[vpnIp]; ok { + // hm.Hosts[vpnIp].PromoteBest(hm.preferredRanges, false) + //fmt.Println(h.remote) + // } + } + hm.RUnlock() + return list +} + +func (hm *HostMap) Punchy(conn *udpConn) { + for { + for _, addr := range hm.PunchList() { + conn.WriteTo([]byte{1}, addr) + } + time.Sleep(time.Second * 30) + } +} + +func (i *HostInfo) MarshalJSON() ([]byte, error) { + return json.Marshal(m{ + "remote": i.remote, + "remotes": i.Remotes, + "promote_counter": i.promoteCounter, + "connection_state": i.ConnectionState, + "handshake_start": i.handshakeStart, + "handshake_ready": i.HandshakeReady, + "handshake_counter": i.HandshakeCounter, + "handshake_complete": i.HandshakeComplete, + "handshake_packet": i.HandshakePacket, + "packet_store": i.packetStore, + "remote_index": i.remoteIndexId, + "local_index": i.localIndexId, + "host_id": int2ip(i.hostId), + "receive_errors": i.recvError, + "last_roam": i.lastRoam, + "last_roam_remote": i.lastRoamRemote, + }) +} + +func (i *HostInfo) BindConnectionState(cs *ConnectionState) { + i.ConnectionState = cs +} + +func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { + if i.remote == nil { + i.ForcePromoteBest(preferredRanges) + return + } + + i.promoteCounter++ + if i.promoteCounter%PromoteEvery == 0 { + // return early if we are already on a preferred remote + rIP := udp2ip(i.remote) + for _, l := range preferredRanges { + if l.Contains(rIP) { + return + } + } + + // We re-query the lighthouse periodically while sending packets, so + // check for new remotes in our local lighthouse cache + ips := ifce.lightHouse.QueryCache(i.hostId) + for _, ip := range ips { + i.AddRemote(ip) + } + + best, preferred := i.getBestRemote(preferredRanges) + if preferred && !best.Equals(i.remote) { + // Try to send a test packet to that host, this should + // cause it to detect a roaming event and switch remotes + ifce.send(test, testRequest, i.ConnectionState, i, best, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + } + } +} + +func (i *HostInfo) ForcePromoteBest(preferredRanges []*net.IPNet) { + best, _ := i.getBestRemote(preferredRanges) + if best != nil { + i.remote = best + } +} + +func (i *HostInfo) getBestRemote(preferredRanges []*net.IPNet) (best *udpAddr, preferred bool) { + if len(i.Remotes) > 0 { + for _, r := range i.Remotes { + rIP := udp2ip(r.addr) + + for _, l := range preferredRanges { + if l.Contains(rIP) { + return r.addr, true + } + } + + if best == nil || !PrivateIP(rIP) { + best = r.addr + } + /* + for _, r := range i.Remotes { + // Must have > 80% probe success to be considered. + //fmt.Println("GRADE:", r.addr.IP, r.Grade()) + if r.Grade() > float64(.8) { + if localToMe.Contains(r.addr.IP) == true { + best = r.addr + break + //i.remote = i.Remotes[c].addr + } else { + //} + } + */ + } + return best, false + } + + return nil, false +} + +// rotateRemote will move remote to the next ip in the list of remote ips for this host +// This is different than PromoteBest in that what is algorithmically best may not actually work. +// Only known use case is when sending a stage 0 handshake. +// It may be better to just send stage 0 handshakes to all known ips and sort it out in the receiver. +func (i *HostInfo) rotateRemote() { + // We have 0, can't rotate + if len(i.Remotes) < 1 { + return + } + + if i.remote == nil { + i.remote = i.Remotes[0].addr + return + } + + // We want to look at all but the very last entry since that is handled at the end + for x := 0; x < len(i.Remotes)-1; x++ { + // Find our current position and move to the next one in the list + if i.Remotes[x].addr.Equals(i.remote) { + i.remote = i.Remotes[x+1].addr + return + } + } + + // Our current position was likely the last in the list, start over at 0 + i.remote = i.Remotes[0].addr +} + +func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) { + //TODO: return the error so we can log with more context + if len(i.packetStore) < 100 { + tempPacket := make([]byte, len(packet)) + copy(tempPacket, packet) + //l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket) + i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket}) + l.WithField("vpnIp", IntIp(i.hostId)). + WithField("length", len(i.packetStore)). + WithField("stored", true). + Debugf("Packet store") + + } else if l.Level >= logrus.DebugLevel { + l.WithField("vpnIp", IntIp(i.hostId)). + WithField("length", len(i.packetStore)). + WithField("stored", false). + Debugf("Packet store") + } +} + +// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets +func (i *HostInfo) handshakeComplete() { + //TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because: + //TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send + //TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical + + i.ConnectionState.queueLock.Lock() + i.HandshakeComplete = true + //TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen. + // Clamping it to 2 gets us out of the woods for now + *i.ConnectionState.messageCounter = 2 + l.WithField("vpnIp", IntIp(i.hostId)).Debugf("Sending %d stored packets", len(i.packetStore)) + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + for _, cp := range i.packetStore { + cp.callback(cp.messageType, cp.messageSubType, i, cp.packet, nb, out) + } + i.packetStore = make([]*cachedPacket, 0) + i.ConnectionState.ready = true + i.ConnectionState.queueLock.Unlock() + i.ConnectionState.certState = nil +} + +func (i *HostInfo) RemoteUDPAddrs() []*udpAddr { + var addrs []*udpAddr + for _, r := range i.Remotes { + addrs = append(addrs, r.addr) + } + return addrs +} + +func (i *HostInfo) GetCert() *cert.NebulaCertificate { + if i.ConnectionState != nil { + return i.ConnectionState.peerCert + } + return nil +} + +func (i *HostInfo) AddRemote(r udpAddr) *udpAddr { + remote := &r + //add := true + for _, r := range i.Remotes { + if r.addr.Equals(remote) { + return r.addr + //add = false + } + } + // Trim this down if necessary + if len(i.Remotes) > MaxRemotes { + i.Remotes = i.Remotes[len(i.Remotes)-MaxRemotes:] + } + i.Remotes = append(i.Remotes, NewHostInfoDest(remote)) + return remote + //l.Debugf("Added remote %s for vpn ip", remote) +} + +func (i *HostInfo) SetRemote(remote udpAddr) { + i.remote = i.AddRemote(remote) +} + +func (i *HostInfo) ClearRemotes() { + i.remote = nil + i.Remotes = []*HostInfoDest{} +} + +func (i *HostInfo) ClearConnectionState() { + i.ConnectionState = nil +} + +func (i *HostInfo) RecvErrorExceeded() bool { + if i.recvError < 3 { + i.recvError += 1 + return false + } + return true +} + +//######################## + +func NewHostInfoDest(addr *udpAddr) *HostInfoDest { + i := &HostInfoDest{ + addr: addr, + } + return i +} + +func (hid *HostInfoDest) MarshalJSON() ([]byte, error) { + return json.Marshal(m{ + "active": hid.active, + "address": hid.addr, + "probe_count": hid.probeCounter, + }) +} + +/* + +func (hm *HostMap) DebugRemotes(vpnIp uint32) string { + s := "\n" + for _, h := range hm.Hosts { + for _, r := range h.Remotes { + s += fmt.Sprintf("%s : %d ## %v\n", r.addr.IP.String(), r.addr.Port, r.probes) + } + } + return s +} + + +func (d *HostInfoDest) Grade() float64 { + c1 := ProbeLen + for n := len(d.probes) - 1; n >= 0; n-- { + if d.probes[n] == true { + c1 -= 1 + } + } + return float64(c1) / float64(ProbeLen) +} + +func (d *HostInfoDest) Grade() (float64, float64, float64) { + c1 := ProbeLen + c2 := ProbeLen / 2 + c2c := ProbeLen - ProbeLen/2 + c3 := ProbeLen / 5 + c3c := ProbeLen - ProbeLen/5 + for n := len(d.probes) - 1; n >= 0; n-- { + if d.probes[n] == true { + c1 -= 1 + if n >= c2c { + c2 -= 1 + if n >= c3c { + c3 -= 1 + } + } + } + //if n >= d { + } + return float64(c3) / float64(ProbeLen/5), float64(c2) / float64(ProbeLen/2), float64(c1) / float64(ProbeLen) + //return float64(c1) / float64(ProbeLen), float64(c2) / float64(ProbeLen/2), float64(c3) / float64(ProbeLen/5) +} + + +func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) { + for _, r := range i.Remotes { + if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port { + r.ProbeReceived(counter) + } + } +} + +func (i *HostInfo) Probes() []*Probe { + p := []*Probe{} + for _, d := range i.Remotes { + p = append(p, &Probe{Addr: d.addr, Counter: d.Probe()}) + } + return p +} + + +func (d *HostInfoDest) Probe() int { + //d.probes = append(d.probes, true) + d.probeCounter++ + d.probes[d.probeCounter%ProbeLen] = true + return d.probeCounter + //return d.probeCounter +} + +func (d *HostInfoDest) ProbeReceived(probeCount int) { + if probeCount >= (d.probeCounter - ProbeLen) { + //fmt.Println("PROBE WORKED", probeCount) + //fmt.Println(d.addr, d.Grade()) + d.probes[probeCount%ProbeLen] = false + } +} + +*/ + +// Utility functions + +func localIps() *[]net.IP { + //FIXME: This function is pretty garbage + var ips []net.IP + ifaces, _ := net.Interfaces() + for _, i := range ifaces { + addrs, _ := i.Addrs() + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + //continue + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + if ip.To4() != nil && ip.IsLoopback() == false { + ips = append(ips, ip) + } + } + } + return &ips +} + +func PrivateIP(ip net.IP) bool { + private := false + _, private24BitBlock, _ := net.ParseCIDR("10.0.0.0/8") + _, private20BitBlock, _ := net.ParseCIDR("172.16.0.0/12") + _, private16BitBlock, _ := net.ParseCIDR("192.168.0.0/16") + private = private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip) + return private +} diff --git a/hostmap_test.go b/hostmap_test.go new file mode 100644 index 0000000..de5e198 --- /dev/null +++ b/hostmap_test.go @@ -0,0 +1,166 @@ +package nebula + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +/* +func TestHostInfoDestProbe(t *testing.T) { + a, _ := net.ResolveUDPAddr("udp", "1.0.0.1:22222") + d := NewHostInfoDest(a) + + // 999 probes that all return should give a 100% success rate + for i := 0; i < 999; i++ { + meh := d.Probe() + d.ProbeReceived(meh) + } + assert.Equal(t, d.Grade(), float64(1)) + + // 999 probes of which only half return should give a 50% success rate + for i := 0; i < 999; i++ { + meh := d.Probe() + if i%2 == 0 { + d.ProbeReceived(meh) + } + } + assert.Equal(t, d.Grade(), float64(.5)) + + // 999 probes of which none return should give a 0% success rate + for i := 0; i < 999; i++ { + d.Probe() + } + assert.Equal(t, d.Grade(), float64(0)) + + // 999 probes of which only 1/4 return should give a 25% success rate + for i := 0; i < 999; i++ { + meh := d.Probe() + if i%4 == 0 { + d.ProbeReceived(meh) + } + } + assert.Equal(t, d.Grade(), float64(.25)) + + // 999 probes of which only half return and are duplicates should give a 50% success rate + for i := 0; i < 999; i++ { + meh := d.Probe() + if i%2 == 0 { + d.ProbeReceived(meh) + d.ProbeReceived(meh) + } + } + assert.Equal(t, d.Grade(), float64(.5)) + + // 999 probes of which only way old replies return should give a 0% success rate + for i := 0; i < 999; i++ { + meh := d.Probe() + d.ProbeReceived(meh - 101) + } + assert.Equal(t, d.Grade(), float64(0)) + +} +*/ + +func TestHostmap(t *testing.T) { + _, myNet, _ := net.ParseCIDR("10.128.0.0/16") + _, localToMe, _ := net.ParseCIDR("192.168.1.0/24") + myNets := []*net.IPNet{myNet} + preferredRanges := []*net.IPNet{localToMe} + + m := NewHostMap("test", myNet, preferredRanges) + + a := NewUDPAddrFromString("10.127.0.3:11111") + b := NewUDPAddrFromString("1.0.0.1:22222") + y := NewUDPAddrFromString("10.128.0.3:11111") + m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a) + m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b) + m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y) + + info, _ := m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1"))) + + // There should be three remotes in the host map + assert.Equal(t, 3, len(info.Remotes)) + + // Adding an identical remote should not change the count + m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y) + assert.Equal(t, 3, len(info.Remotes)) + + // Adding a fresh remote should add one + y = NewUDPAddrFromString("10.18.0.3:11111") + m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y) + assert.Equal(t, 4, len(info.Remotes)) + + // Query and reference remote should get the first one (and not nil) + info, _ = m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1"))) + assert.NotNil(t, info.remote) + + // Promotion should ensure that the best remote is chosen (y) + info.ForcePromoteBest(myNets) + assert.True(t, myNet.Contains(udp2ip(info.remote))) + +} + +func TestHostmapdebug(t *testing.T) { + _, myNet, _ := net.ParseCIDR("10.128.0.0/16") + _, localToMe, _ := net.ParseCIDR("192.168.1.0/24") + preferredRanges := []*net.IPNet{localToMe} + m := NewHostMap("test", myNet, preferredRanges) + + a := NewUDPAddrFromString("10.127.0.3:11111") + b := NewUDPAddrFromString("1.0.0.1:22222") + y := NewUDPAddrFromString("10.128.0.3:11111") + m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a) + m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b) + m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y) + + //t.Errorf("%s", m.DebugRemotes(1)) +} + +func TestHostMap_rotateRemote(t *testing.T) { + h := HostInfo{} + // 0 remotes, no panic + h.rotateRemote() + assert.Nil(t, h.remote) + + // 1 remote, no panic + h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 1}), 0)) + h.rotateRemote() + assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 1})) + + h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 2}), 0)) + h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 3}), 0)) + h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 4}), 0)) + + // Rotate through those 3 + h.rotateRemote() + assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 2})) + + h.rotateRemote() + assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 3})) + + h.rotateRemote() + assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 4})) + + // Finally, we should start over + h.rotateRemote() + assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 1})) +} + +func BenchmarkHostmappromote2(b *testing.B) { + for n := 0; n < b.N; n++ { + _, myNet, _ := net.ParseCIDR("10.128.0.0/16") + _, localToMe, _ := net.ParseCIDR("192.168.1.0/24") + preferredRanges := []*net.IPNet{localToMe} + m := NewHostMap("test", myNet, preferredRanges) + y := NewUDPAddrFromString("10.128.0.3:11111") + a := NewUDPAddrFromString("10.127.0.3:11111") + g := NewUDPAddrFromString("1.0.0.1:22222") + m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a) + m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), g) + m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y) + } + b.Errorf("hi") + +} diff --git a/inside.go b/inside.go new file mode 100644 index 0000000..34022aa --- /dev/null +++ b/inside.go @@ -0,0 +1,201 @@ +package nebula + +import ( + "sync/atomic" + + "github.com/flynn/noise" + "github.com/sirupsen/logrus" +) + +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte) { + err := newPacket(packet, false, fwPacket) + if err != nil { + l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) + return + } + + // Ignore local broadcast packets + if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast { + return + } + + // Ignore broadcast packets + if f.dropMulticast && isMulticast(fwPacket.RemoteIP) { + return + } + + hostinfo := f.getOrHandshake(fwPacket.RemoteIP) + ci := hostinfo.ConnectionState + + if ci.ready == false { + // Because we might be sending stored packets, lock here to stop new things going to + // the packet queue. + ci.queueLock.Lock() + if !ci.ready { + hostinfo.cachePacket(message, 0, packet, f.sendMessageNow) + ci.queueLock.Unlock() + return + } + ci.queueLock.Unlock() + } + + if !f.firewall.Drop(packet, *fwPacket, false, ci.peerCert, trustedCAs) { + f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out) + if f.lightHouse != nil && *ci.messageCounter%5000 == 0 { + f.lightHouse.Query(fwPacket.RemoteIP, f) + } + + } else if l.Level >= logrus.DebugLevel { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket). + Debugln("dropping outbound packet") + } +} + +func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo { + hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f) + + //if err != nil || hostinfo.ConnectionState == nil { + if err != nil { + hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIp) + if err != nil { + hostinfo = f.handshakeManager.AddVpnIP(vpnIp) + } + } + + ci := hostinfo.ConnectionState + + if ci != nil && ci.eKey != nil && ci.ready { + return hostinfo + } + + if ci == nil { + // if we don't have a connection state, then send a handshake initiation + ci = f.newConnectionState(true, noise.HandshakeIX, []byte{}, 0) + // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. + //ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0) + hostinfo.ConnectionState = ci + } else if ci.eKey == nil { + // if we don't have any state at all, create it + } + + // If we have already created the handshake packet, we don't want to call the function at all. + if !hostinfo.HandshakeReady { + ixHandshakeStage0(f, vpnIp, hostinfo) + // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. + //xx_handshakeStage0(f, ip, hostinfo) + } + + return hostinfo +} + +func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) { + fp := &FirewallPacket{} + err := newPacket(p, false, fp) + if err != nil { + l.Warnf("error while parsing outgoing packet for firewall check; %v", err) + return + } + + // check if packet is in outbound fw rules + if f.firewall.Drop(p, *fp, false, hostInfo.ConnectionState.peerCert, trustedCAs) { + l.WithField("fwPacket", fp).Debugln("dropping cached packet") + return + } + + f.send(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out) + if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 { + f.lightHouse.Query(fp.RemoteIP, f) + } +} + +// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp +func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { + hostInfo := f.getOrHandshake(vpnIp) + + if !hostInfo.ConnectionState.ready { + // Because we might be sending stored packets, lock here to stop new things going to + // the packet queue. + hostInfo.ConnectionState.queueLock.Lock() + if !hostInfo.ConnectionState.ready { + hostInfo.cachePacket(t, st, p, f.sendMessageToVpnIp) + hostInfo.ConnectionState.queueLock.Unlock() + return + } + hostInfo.ConnectionState.queueLock.Unlock() + } + + f.sendMessageToVpnIp(t, st, hostInfo, p, nb, out) + return +} + +func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) { + f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out) +} + +// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp +func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { + hostInfo := f.getOrHandshake(vpnIp) + + if hostInfo.ConnectionState.ready == false { + // Because we might be sending stored packets, lock here to stop new things going to + // the packet queue. + hostInfo.ConnectionState.queueLock.Lock() + if !hostInfo.ConnectionState.ready { + hostInfo.cachePacket(t, st, p, f.sendMessageToAll) + hostInfo.ConnectionState.queueLock.Unlock() + return + } + hostInfo.ConnectionState.queueLock.Unlock() + } + + f.sendMessageToAll(t, st, hostInfo, p, nb, out) + return +} + +func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, b []byte) { + for _, r := range hostInfo.RemoteUDPAddrs() { + f.send(t, st, hostInfo.ConnectionState, hostInfo, r, p, nb, b) + } +} + +func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) { + if ci.eKey == nil { + //TODO: log warning + return + } + + var err error + //TODO: enable if we do more than 1 tun queue + //ci.writeLock.Lock() + c := atomic.AddUint64(ci.messageCounter, 1) + + //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) + out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo.hostId) + + out, err = ci.eKey.EncryptDanger(out, out, p, c, nb) + //TODO: see above note on lock + //ci.writeLock.Unlock() + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)). + WithField("udpAddr", remote).WithField("counter", c). + WithField("attemptedCounter", ci.messageCounter). + Error("Failed to encrypt outgoing packet") + return + } + + err = f.outside.WriteTo(out, remote) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)). + WithField("udpAddr", remote).Error("Failed to write outgoing packet") + } +} + +func isMulticast(ip uint32) bool { + // Class D multicast + if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 { + return true + } + + return false +} diff --git a/interface.go b/interface.go new file mode 100644 index 0000000..2cbb776 --- /dev/null +++ b/interface.go @@ -0,0 +1,277 @@ +package nebula + +import ( + "crypto/sha256" + "errors" + "fmt" + "io" + "os" + "time" + + "github.com/rcrowley/go-metrics" + "golang.org/x/crypto/hkdf" +) + +const mtu = 9001 + +type InterfaceConfig struct { + HostMap *HostMap + Outside *udpConn + Inside *Tun + certState *CertState + Cipher string + Firewall *Firewall + ServeDns bool + HandshakeManager *HandshakeManager + lightHouse *LightHouse + checkInterval int + pendingDeletionInterval int + handshakeMACKey string + handshakeAcceptedMACKeys []string + DropLocalBroadcast bool + DropMulticast bool + UDPBatchSize int +} + +type Interface struct { + hostMap *HostMap + outside *udpConn + inside *Tun + certState *CertState + cipher string + firewall *Firewall + connectionManager *connectionManager + handshakeManager *HandshakeManager + serveDns bool + createTime time.Time + lightHouse *LightHouse + handshakeMACKey []byte + handshakeAcceptedMACKeys [][]byte + localBroadcast uint32 + dropLocalBroadcast bool + dropMulticast bool + udpBatchSize int + version string + + metricRxRecvError metrics.Counter + metricTxRecvError metrics.Counter + metricHandshakes metrics.Histogram +} + +func NewInterface(c *InterfaceConfig) (*Interface, error) { + if c.Outside == nil { + return nil, errors.New("no outside connection") + } + if c.Inside == nil { + return nil, errors.New("no inside interface (tun)") + } + if c.certState == nil { + return nil, errors.New("no certificate state") + } + if c.Firewall == nil { + return nil, errors.New("no firewall rules") + } + + // Use KDF to make this useful + hmacKey, err := sha256KdfFromString(c.handshakeMACKey) + if err != nil { + l.Debugln(err) + } + + allowedMacs := make([][]byte, 0) + //allowedMacs = append(allowedMacs, mac) + if len(c.handshakeAcceptedMACKeys) > 0 { + for _, k := range c.handshakeAcceptedMACKeys { + // Use KDF to make these useful too + hmacKey, err := sha256KdfFromString(k) + if err != nil { + l.Debugln(err) + } + allowedMacs = append(allowedMacs, hmacKey) + } + } else { + if len(c.handshakeMACKey) > 0 { + l.Warnln("You have set an outgoing MAC but do not accept any incoming. This is probably not what you want.") + } else { + // This else is a fallback if we have not set any mac keys at all + hmacKey, err := sha256KdfFromString("") + if err != nil { + l.Debugln(err) + } + allowedMacs = append(allowedMacs, hmacKey) + + } + } + + ifce := &Interface{ + hostMap: c.HostMap, + outside: c.Outside, + inside: c.Inside, + certState: c.certState, + cipher: c.Cipher, + firewall: c.Firewall, + serveDns: c.ServeDns, + handshakeManager: c.HandshakeManager, + createTime: time.Now(), + lightHouse: c.lightHouse, + handshakeMACKey: hmacKey, + handshakeAcceptedMACKeys: allowedMacs, + localBroadcast: ip2int(c.certState.certificate.Details.Ips[0].IP) | ^ip2int(c.certState.certificate.Details.Ips[0].Mask), + dropLocalBroadcast: c.DropLocalBroadcast, + dropMulticast: c.DropMulticast, + udpBatchSize: c.UDPBatchSize, + + metricRxRecvError: metrics.GetOrRegisterCounter("messages.rx.recv_error", nil), + metricTxRecvError: metrics.GetOrRegisterCounter("messages.tx.recv_error", nil), + metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), + } + + ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval) + + return ifce, nil +} + +func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) { + // actually turn on tun dev + if err := f.inside.Activate(); err != nil { + l.Fatal(err) + } + + f.version = buildVersion + l.WithField("interface", f.inside.Device).WithField("network", f.inside.Cidr.String()). + WithField("build", buildVersion). + Info("Nebula interface is active") + + // Launch n queues to read packets from udp + for i := 0; i < udpRoutines; i++ { + go f.listenOut(i) + } + + // Launch n queues to read packets from tun dev + for i := 0; i < tunRoutines; i++ { + go f.listenIn(i) + } +} + +func (f *Interface) listenOut(i int) { + //TODO: handle error + addr, err := f.outside.LocalAddr() + if err != nil { + l.WithError(err).Error("failed to discover udp listening address") + } + + var li *udpConn + if i > 0 { + //TODO: handle error + li, err = NewListener(udp2ip(addr).String(), int(addr.Port), i > 0) + if err != nil { + l.WithError(err).Error("failed to make a new udp listener") + } + } else { + li = f.outside + } + + li.ListenOut(f) +} + +func (f *Interface) listenIn(i int) { + packet := make([]byte, mtu) + out := make([]byte, mtu) + fwPacket := &FirewallPacket{} + nb := make([]byte, 12, 12) + + for { + n, err := f.inside.Read(packet) + if err != nil { + l.WithError(err).Error("Error while reading outbound packet") + // This only seems to happen when something fatal happens to the fd, so exit. + os.Exit(2) + } + + f.consumeInsidePacket(packet[:n], fwPacket, nb, out) + } +} + +func (f *Interface) RegisterConfigChangeCallbacks(c *Config) { + c.RegisterReloadCallback(f.reloadCA) + c.RegisterReloadCallback(f.reloadCertKey) + c.RegisterReloadCallback(f.reloadFirewall) + c.RegisterReloadCallback(f.outside.reloadConfig) +} + +func (f *Interface) reloadCA(c *Config) { + // reload and check regardless + // todo: need mutex? + newCAs, err := loadCAFromConfig(c) + if err != nil { + l.WithError(err).Error("Could not refresh trusted CA certificates") + return + } + + trustedCAs = newCAs + l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed") +} + +func (f *Interface) reloadCertKey(c *Config) { + // reload and check in all cases + cs, err := NewCertStateFromConfig(c) + if err != nil { + l.WithError(err).Error("Could not refresh client cert") + return + } + + // did IP in cert change? if so, don't set + oldIPs := f.certState.certificate.Details.Ips + newIPs := cs.certificate.Details.Ips + if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { + l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old") + return + } + + f.certState = cs + l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk") +} + +func (f *Interface) reloadFirewall(c *Config) { + //TODO: need to trigger/detect if the certificate changed too + if c.HasChanged("firewall") == false { + l.Debug("No firewall config change detected") + return + } + + fw, err := NewFirewallFromConfig(f.certState.certificate, c) + if err != nil { + l.WithError(err).Error("Error while creating firewall during reload") + return + } + + oldFw := f.firewall + f.firewall = fw + + oldFw.Destroy() + l.WithField("firewallHash", fw.GetRuleHash()). + WithField("oldFirewallHash", oldFw.GetRuleHash()). + Info("New firewall has been installed") +} + +func (f *Interface) emitStats(i time.Duration) { + ticker := time.NewTicker(i) + for range ticker.C { + f.firewall.EmitStats() + f.handshakeManager.EmitStats() + } +} + +func sha256KdfFromString(secret string) ([]byte, error) { + // Use KDF to make this useful + mac := []byte(secret) + hmacKey := make([]byte, sha256.BlockSize) + hash := sha256.New + hkdfer := hkdf.New(hash, []byte(mac), nil, nil) + n, err := io.ReadFull(hkdfer, hmacKey) + if n != len(hmacKey) || err != nil { + l.Errorln("KDF Failed!") + return nil, fmt.Errorf("%s", err) + } + return hmacKey, nil +} diff --git a/lighthouse.go b/lighthouse.go new file mode 100644 index 0000000..5e19d65 --- /dev/null +++ b/lighthouse.go @@ -0,0 +1,368 @@ +package nebula + +import ( + "fmt" + "net" + "sync" + "time" + + "github.com/golang/protobuf/proto" + "github.com/slackhq/nebula/cert" +) + +type LightHouse struct { + sync.RWMutex //Because we concurrently read and write to our maps + amLighthouse bool + myIp uint32 + punchConn *udpConn + + // Local cache of answers from light houses + addrMap map[uint32][]udpAddr + + // staticList exists to avoid having a bool in each addrMap entry + // since static should be rare + staticList map[uint32]struct{} + lighthouses map[uint32]struct{} + interval int + nebulaPort int + punchBack bool +} + +type EncWriter interface { + SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) + SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) +} + +func NewLightHouse(amLighthouse bool, myIp uint32, ips []string, interval int, nebulaPort int, pc *udpConn, punchBack bool) *LightHouse { + h := LightHouse{ + amLighthouse: amLighthouse, + myIp: myIp, + addrMap: make(map[uint32][]udpAddr), + nebulaPort: nebulaPort, + lighthouses: make(map[uint32]struct{}), + staticList: make(map[uint32]struct{}), + interval: interval, + punchConn: pc, + punchBack: punchBack, + } + + for _, rIp := range ips { + h.lighthouses[ip2int(net.ParseIP(rIp))] = struct{}{} + } + + return &h +} + +func (lh *LightHouse) Query(ip uint32, f EncWriter) ([]udpAddr, error) { + if !lh.IsLighthouseIP(ip) { + lh.QueryServer(ip, f) + } + lh.RLock() + if v, ok := lh.addrMap[ip]; ok { + lh.RUnlock() + return v, nil + } + lh.RUnlock() + return nil, fmt.Errorf("host %s not known, queries sent to lighthouses", IntIp(ip)) +} + +// This is asynchronous so no reply should be expected +func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) { + if !lh.amLighthouse { + // Send a query to the lighthouses and hope for the best next time + query, err := proto.Marshal(NewLhQueryByInt(ip)) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload") + return + } + + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + for n := range lh.lighthouses { + f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out) + } + } +} + +// Query our local lighthouse cached results +func (lh *LightHouse) QueryCache(ip uint32) []udpAddr { + lh.RLock() + if v, ok := lh.addrMap[ip]; ok { + lh.RUnlock() + return v + } + lh.RUnlock() + return nil +} + +func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) { + // First we check the static mapping + // and do nothing if it is there + if _, ok := lh.staticList[vpnIP]; ok { + return + } + lh.Lock() + //l.Debugln(lh.addrMap) + delete(lh.addrMap, vpnIP) + l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP)) + lh.Unlock() +} + +func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) { + // First we check if the sender thinks this is a static entry + // and do nothing if it is not, but should be considered static + if static == false { + if _, ok := lh.staticList[vpnIP]; ok { + return + } + } + + lh.Lock() + for _, v := range lh.addrMap[vpnIP] { + if v.Equals(toIp) { + lh.Unlock() + return + } + } + //l.Debugf("Adding reply of %s as %s\n", IntIp(vpnIP), toIp) + if static { + lh.staticList[vpnIP] = struct{}{} + } + lh.addrMap[vpnIP] = append(lh.addrMap[vpnIP], *toIp) + lh.Unlock() +} + +func (lh *LightHouse) AddRemoteAndReset(vpnIP uint32, toIp *udpAddr) { + if lh.amLighthouse { + lh.DeleteVpnIP(vpnIP) + lh.AddRemote(vpnIP, toIp, false) + } + +} + +func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool { + if _, ok := lh.lighthouses[vpnIP]; ok { + return true + } + return false +} + +// Quick generators for protobuf + +func NewLhQueryByIpString(VpnIp string) *NebulaMeta { + return NewLhQueryByInt(ip2int(net.ParseIP(VpnIp))) +} + +func NewLhQueryByInt(VpnIp uint32) *NebulaMeta { + return &NebulaMeta{ + Type: NebulaMeta_HostQuery, + Details: &NebulaMetaDetails{ + VpnIp: VpnIp, + }, + } +} + +func NewLhWhoami() *NebulaMeta { + return &NebulaMeta{ + Type: NebulaMeta_HostWhoami, + Details: &NebulaMetaDetails{}, + } +} + +// End Quick generators for protobuf + +func NewIpAndPortFromUDPAddr(addr udpAddr) *IpAndPort { + return &IpAndPort{Ip: udp2ipInt(&addr), Port: uint32(addr.Port)} +} + +func NewIpAndPortsFromNetIps(ips []udpAddr) *[]*IpAndPort { + var iap []*IpAndPort + for _, e := range ips { + // Only add IPs that aren't my VPN/tun IP + iap = append(iap, NewIpAndPortFromUDPAddr(e)) + } + return &iap +} + +func (lh *LightHouse) LhUpdateWorker(f EncWriter) { + if lh.amLighthouse { + return + } + + for { + ipp := []*IpAndPort{} + + for _, e := range *localIps() { + // Only add IPs that aren't my VPN/tun IP + if ip2int(e) != lh.myIp { + ipp = append(ipp, &IpAndPort{Ip: ip2int(e), Port: uint32(lh.nebulaPort)}) + //fmt.Println(e) + } + } + m := &NebulaMeta{ + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{ + VpnIp: lh.myIp, + IpAndPorts: ipp, + }, + } + + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + for vpnIp := range lh.lighthouses { + mm, err := proto.Marshal(m) + if err != nil { + l.Debugf("Invalid marshal to update") + } + //l.Error("LIGHTHOUSE PACKET SEND", mm) + f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out) + + } + time.Sleep(time.Second * time.Duration(lh.interval)) + } +} + +func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *cert.NebulaCertificate, f EncWriter) { + n := &NebulaMeta{} + err := proto.Unmarshal(p, n) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). + Error("Failed to unmarshal lighthouse packet") + //TODO: send recv_error? + return + } + + if n.Details == nil { + l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). + Error("Invalid lighthouse update") + //TODO: send recv_error? + return + } + + switch n.Type { + case NebulaMeta_HostQuery: + // Exit if we don't answer queries + if !lh.amLighthouse { + l.Debugln("I don't answer queries, but received from: ", rAddr) + return + } + + //l.Debugln("Got Query") + ips, err := lh.Query(n.Details.VpnIp, f) + if err != nil { + //l.Debugf("Can't answer query %s from %s because error: %s", IntIp(n.Details.VpnIp), rAddr, err) + return + } else { + iap := NewIpAndPortsFromNetIps(ips) + answer := &NebulaMeta{ + Type: NebulaMeta_HostQueryReply, + Details: &NebulaMetaDetails{ + VpnIp: n.Details.VpnIp, + IpAndPorts: *iap, + }, + } + reply, err := proto.Marshal(answer) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply") + return + } + f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu)) + + // This signals the other side to punch some zero byte udp packets + ips, err = lh.Query(vpnIp, f) + if err != nil { + l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch") + return + } else { + //l.Debugln("Notify host to punch", iap) + iap = NewIpAndPortsFromNetIps(ips) + answer = &NebulaMeta{ + Type: NebulaMeta_HostPunchNotification, + Details: &NebulaMetaDetails{ + VpnIp: vpnIp, + IpAndPorts: *iap, + }, + } + reply, _ := proto.Marshal(answer) + f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu)) + } + //fmt.Println(reply, remoteaddr) + } + + case NebulaMeta_HostQueryReply: + if !lh.IsLighthouseIP(vpnIp) { + return + } + for _, a := range n.Details.IpAndPorts { + //first := n.Details.IpAndPorts[0] + ans := NewUDPAddr(a.Ip, uint16(a.Port)) + lh.AddRemote(n.Details.VpnIp, ans, false) + } + + case NebulaMeta_HostUpdateNotification: + //Simple check that the host sent this not someone else + if n.Details.VpnIp != vpnIp { + l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update") + return + } + for _, a := range n.Details.IpAndPorts { + ans := NewUDPAddr(a.Ip, uint16(a.Port)) + lh.AddRemote(n.Details.VpnIp, ans, false) + } + case NebulaMeta_HostMovedNotification: + case NebulaMeta_HostPunchNotification: + if !lh.IsLighthouseIP(vpnIp) { + return + } + + empty := []byte{0} + for _, a := range n.Details.IpAndPorts { + vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port)) + go func() { + for i := 0; i < 5; i++ { + lh.punchConn.WriteTo(empty, vpnPeer) + time.Sleep(time.Second * 1) + } + + }() + l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) + } + // This sends a nebula test packet to the host trying to contact us. In the case + // of a double nat or other difficult scenario, this may help establish + // a tunnel. + if lh.punchBack { + go func() { + time.Sleep(time.Second * 5) + l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp)) + f.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + }() + } + } +} + +/* +func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) { + c := ci.messageCounter + b := HeaderEncode(nil, Version, uint8(path_check), 0, ci.remoteIndex, c) + ci.messageCounter++ + + if ci.eKey != nil { + msg := ci.eKey.EncryptDanger(b, nil, []byte(strconv.Itoa(counter)), c) + //msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c) + f.outside.WriteTo(msg, endpoint) + l.Debugf("path_check sent, remote index: %d, pathCounter %d", ci.remoteIndex, counter) + } +} + +func (f *Interface) sendPathCheckReply(ci *ConnectionState, endpoint *net.UDPAddr, counter []byte) { + c := ci.messageCounter + b := HeaderEncode(nil, Version, uint8(path_check_reply), 0, ci.remoteIndex, c) + ci.messageCounter++ + + if ci.eKey != nil { + msg := ci.eKey.EncryptDanger(b, nil, counter, c) + f.outside.WriteTo(msg, endpoint) + l.Debugln("path_check sent, remote index: ", ci.remoteIndex) + } +} +*/ diff --git a/lighthouse_test.go b/lighthouse_test.go new file mode 100644 index 0000000..b0ff492 --- /dev/null +++ b/lighthouse_test.go @@ -0,0 +1,76 @@ +package nebula + +import ( + "net" + "testing" + + proto "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" +) + +func TestNewLhQuery(t *testing.T) { + myIp := net.ParseIP("192.1.1.1") + myIpint := ip2int(myIp) + + // Generating a new lh query should work + a := NewLhQueryByInt(myIpint) + + // The result should be a nebulameta protobuf + assert.IsType(t, &NebulaMeta{}, a) + + // It should also Marshal fine + b, err := proto.Marshal(a) + assert.Nil(t, err) + + // and then Unmarshal fine + n := &NebulaMeta{} + err = proto.Unmarshal(b, n) + assert.Nil(t, err) + +} + +func TestNewipandportfromudpaddr(t *testing.T) { + blah := NewUDPAddrFromString("1.2.2.3:12345") + meh := NewIpAndPortFromUDPAddr(*blah) + assert.Equal(t, uint32(16908803), meh.Ip) + assert.Equal(t, uint32(12345), meh.Port) +} + +func TestNewipandportsfromudpaddrs(t *testing.T) { + blah := NewUDPAddrFromString("1.2.2.3:12345") + blah2 := NewUDPAddrFromString("9.9.9.9:47828") + group := []udpAddr{*blah, *blah2} + hah := NewIpAndPortsFromNetIps(group) + assert.IsType(t, &[]*IpAndPort{}, hah) + //t.Error(reflect.TypeOf(hah)) + +} + +/* +func TestLHQuery(t *testing.T) { + //n := NewLhQueryByIpString("10.128.0.3") + _, myNet, _ := net.ParseCIDR("10.128.0.0/16") + m := NewHostMap(myNet) + y, _ := net.ResolveUDPAddr("udp", "10.128.0.3:11111") + m.Add(ip2int(net.ParseIP("127.0.0.1")), y) + //t.Errorf("%s", m) + _ = m + + _, n, _ := net.ParseCIDR("127.0.0.1/8") + + /*udpServer, err := net.ListenUDP("udp", &net.UDPAddr{Port: 10009}) + if err != nil { + t.Errorf("%s", err) + } + + meh := NewLightHouse(n, m, []string{"10.128.0.2"}, false, 10, 10003, 10004) + //t.Error(m.Hosts) + meh2, err := meh.Query(ip2int(net.ParseIP("10.128.0.3"))) + t.Error(err) + if err != nil { + return + } + t.Errorf("%s", meh2) + t.Errorf("%s", n) +} +*/ diff --git a/main.go b/main.go new file mode 100644 index 0000000..fa739ca --- /dev/null +++ b/main.go @@ -0,0 +1,321 @@ +package nebula + +import ( + "encoding/binary" + "fmt" + "net" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + "time" + + "github.com/sirupsen/logrus" + "gopkg.in/yaml.v2" + "github.com/slackhq/nebula/sshd" +) + +var l = logrus.New() + +type m map[string]interface{} + +func Main(configPath string, configTest bool, buildVersion string) { + l.Out = os.Stdout + l.Formatter = &logrus.TextFormatter{ + FullTimestamp: true, + } + + config := NewConfig() + err := config.Load(configPath) + if err != nil { + l.WithError(err).Error("Failed to load config") + os.Exit(1) + } + + // Print the config if in test, the exit comes later + if configTest { + b, err := yaml.Marshal(config.Settings) + if err != nil { + l.Println(err) + os.Exit(1) + } + l.Println(string(b)) + } + + err = configLogger(config) + if err != nil { + l.WithError(err).Error("Failed to configure the logger") + } + + config.RegisterReloadCallback(func(c *Config) { + err := configLogger(c) + if err != nil { + l.WithError(err).Error("Failed to configure the logger") + } + }) + + // trustedCAs is currently a global, so loadCA operates on that global directly + trustedCAs, err = loadCAFromConfig(config) + if err != nil { + //The errors coming out of loadCA are already nicely formatted + l.Fatal(err) + } + l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints") + + cs, err := NewCertStateFromConfig(config) + if err != nil { + //The errors coming out of NewCertStateFromConfig are already nicely formatted + l.Fatal(err) + } + l.WithField("cert", cs.certificate).Debug("Client nebula certificate") + + fw, err := NewFirewallFromConfig(cs.certificate, config) + if err != nil { + l.Fatal("Error while loading firewall rules: ", err) + } + l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") + + // TODO: make sure mask is 4 bytes + tunCidr := cs.certificate.Details.Ips[0] + routes, err := parseRoutes(config, tunCidr) + if err != nil { + l.WithError(err).Fatal("Could not parse tun.routes") + } + + ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) + wireSSHReload(ssh, config) + if config.GetBool("sshd.enabled", false) { + err = configSSH(ssh, config) + if err != nil { + l.WithError(err).Fatal("Error while configuring the sshd") + } + } + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // All non system modifying configuration consumption should live above this line + // tun config, listeners, anything modifying the computer should be below + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + if configTest { + os.Exit(0) + } + + config.CatchHUP() + + // set up our tun dev + tun, err := newTun( + config.GetString("tun.dev", ""), + tunCidr, + config.GetInt("tun.mtu", 1300), + routes, + config.GetInt("tun.tx_queue", 500), + ) + if err != nil { + l.Fatal(err) + } + + // set up our UDP listener + udpQueues := config.GetInt("listen.routines", 1) + udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1) + if err != nil { + l.Fatal(err) + } + udpServer.reloadConfig(config) + + // Set up my internal host map + var preferredRanges []*net.IPNet + rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{}) + // First, check if 'preferred_ranges' is set and fallback to 'local_range' + if len(rawPreferredRanges) > 0 { + for _, rawPreferredRange := range rawPreferredRanges { + _, preferredRange, err := net.ParseCIDR(rawPreferredRange) + if err != nil { + l.Fatal(err) + } + preferredRanges = append(preferredRanges, preferredRange) + } + } + + // local_range was superseded by preferred_ranges. If it is still present, + // merge the local_range setting into preferred_ranges. We will probably + // deprecate local_range and remove in the future. + rawLocalRange := config.GetString("local_range", "") + if rawLocalRange != "" { + _, localRange, err := net.ParseCIDR(rawLocalRange) + if err != nil { + l.Fatal(err) + } + + // Check if the entry for local_range was already specified in + // preferred_ranges. Don't put it into the slice twice if so. + var found bool + for _, r := range preferredRanges { + if r.String() == localRange.String() { + found = true + break + } + } + if !found { + preferredRanges = append(preferredRanges, localRange) + } + } + + hostMap := NewHostMap("main", tunCidr, preferredRanges) + hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0")))) + l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created") + + /* + config.SetDefault("promoter.interval", 10) + go hostMap.Promoter(config.GetInt("promoter.interval")) + */ + + punchy := config.GetBool("punchy", false) + if punchy == true { + l.Info("UDP hole punching enabled") + go hostMap.Punchy(udpServer) + } + + port := config.GetInt("listen.port", 0) + // If port is dynamic, discover it + if port == 0 { + uPort, err := udpServer.LocalAddr() + if err != nil { + l.WithError(err).Fatal("Failed to get listening port") + } + port = int(uPort.Port) + } + + punchBack := config.GetBool("punch_back", false) + amLighthouse := config.GetBool("lighthouse.am_lighthouse", false) + serveDns := config.GetBool("lighthouse.serve_dns", false) + lightHouse := NewLightHouse( + amLighthouse, + ip2int(tunCidr.IP), + config.GetStringSlice("lighthouse.hosts", []string{}), + //TODO: change to a duration + config.GetInt("lighthouse.interval", 10), + port, + udpServer, + punchBack, + ) + + if amLighthouse && serveDns { + l.Debugln("Starting dns server") + go dnsMain(hostMap) + } + + for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) { + vpnIp := net.ParseIP(fmt.Sprintf("%v", k)) + vals, ok := v.([]interface{}) + if ok { + for _, v := range vals { + parts := strings.Split(fmt.Sprintf("%v", v), ":") + addr, err := net.ResolveIPAddr("ip", parts[0]) + if err == nil { + ip := addr.IP + port, err := strconv.Atoi(parts[1]) + if err != nil { + l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) + } + lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) + } + } + } else { + //TODO: make this all a helper + parts := strings.Split(fmt.Sprintf("%v", v), ":") + addr, err := net.ResolveIPAddr("ip", parts[0]) + if err == nil { + ip := addr.IP + port, err := strconv.Atoi(parts[1]) + if err != nil { + l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) + } + lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) + } + } + } + + handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer) + + handshakeMACKey := config.GetString("handshake_mac.key", "") + handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{}) + + checkInterval := config.GetInt("timers.connection_alive_interval", 5) + pendingDeletionInterval := config.GetInt("timers.pending_deletion_interval", 10) + ifConfig := &InterfaceConfig{ + HostMap: hostMap, + Inside: tun, + Outside: udpServer, + certState: cs, + Cipher: config.GetString("cipher", "aes"), + Firewall: fw, + ServeDns: serveDns, + HandshakeManager: handshakeManager, + lightHouse: lightHouse, + checkInterval: checkInterval, + pendingDeletionInterval: pendingDeletionInterval, + handshakeMACKey: handshakeMACKey, + handshakeAcceptedMACKeys: handshakeAcceptedMACKeys, + DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false), + DropMulticast: config.GetBool("tun.drop_multicast", false), + UDPBatchSize: config.GetInt("listen.batch", 64), + } + + switch ifConfig.Cipher { + case "aes": + noiseEndiannes = binary.BigEndian + case "chachapoly": + noiseEndiannes = binary.LittleEndian + default: + l.Fatalf("Unknown cipher: %v", ifConfig.Cipher) + } + + ifce, err := NewInterface(ifConfig) + if err != nil { + l.Fatal(err) + } + + ifce.RegisterConfigChangeCallbacks(config) + + go handshakeManager.Run(ifce) + go lightHouse.LhUpdateWorker(ifce) + + err = startStats(config) + if err != nil { + l.Fatal(err) + } + + //TODO: check if we _should_ be emitting stats + go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10)) + + attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) + ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion) + + // Just sit here and be friendly, main thread. + shutdownBlock(ifce) +} + +func shutdownBlock(ifce *Interface) { + var sigChan = make(chan os.Signal) + signal.Notify(sigChan, syscall.SIGTERM) + signal.Notify(sigChan, syscall.SIGINT) + + sig := <-sigChan + l.WithField("signal", sig).Info("Caught signal, shutting down") + + //TODO: stop tun and udp routines, the lock on hostMap does effectively does that though + //TODO: this is probably better as a function in ConnectionManager or HostMap directly + ifce.hostMap.Lock() + for _, h := range ifce.hostMap.Hosts { + if h.ConnectionState.ready { + ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) + l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote). + Debug("Sending close tunnel message") + } + } + ifce.hostMap.Unlock() + + l.WithField("signal", sig).Info("Goodbye") + os.Exit(0) +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..2808317 --- /dev/null +++ b/main_test.go @@ -0,0 +1 @@ +package nebula diff --git a/metadata.go b/metadata.go new file mode 100644 index 0000000..7ffcc5e --- /dev/null +++ b/metadata.go @@ -0,0 +1,18 @@ +package nebula + +/* + +import ( + proto "github.com/golang/protobuf/proto" +) + +func HandleMetaProto(p []byte) { + m := &NebulaMeta{} + err := proto.Unmarshal(p, m) + if err != nil { + l.Debugf("problem unmarshaling meta message: %s", err) + } + //fmt.Println(m) +} + +*/ diff --git a/nebula.pb.go b/nebula.pb.go new file mode 100644 index 0000000..22f7498 --- /dev/null +++ b/nebula.pb.go @@ -0,0 +1,457 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: nebula.proto + +package nebula + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type NebulaMeta_MessageType int32 + +const ( + NebulaMeta_None NebulaMeta_MessageType = 0 + NebulaMeta_HostQuery NebulaMeta_MessageType = 1 + NebulaMeta_HostQueryReply NebulaMeta_MessageType = 2 + NebulaMeta_HostUpdateNotification NebulaMeta_MessageType = 3 + NebulaMeta_HostMovedNotification NebulaMeta_MessageType = 4 + NebulaMeta_HostPunchNotification NebulaMeta_MessageType = 5 + NebulaMeta_HostWhoami NebulaMeta_MessageType = 6 + NebulaMeta_HostWhoamiReply NebulaMeta_MessageType = 7 + NebulaMeta_PathCheck NebulaMeta_MessageType = 8 + NebulaMeta_PathCheckReply NebulaMeta_MessageType = 9 +) + +var NebulaMeta_MessageType_name = map[int32]string{ + 0: "None", + 1: "HostQuery", + 2: "HostQueryReply", + 3: "HostUpdateNotification", + 4: "HostMovedNotification", + 5: "HostPunchNotification", + 6: "HostWhoami", + 7: "HostWhoamiReply", + 8: "PathCheck", + 9: "PathCheckReply", +} + +var NebulaMeta_MessageType_value = map[string]int32{ + "None": 0, + "HostQuery": 1, + "HostQueryReply": 2, + "HostUpdateNotification": 3, + "HostMovedNotification": 4, + "HostPunchNotification": 5, + "HostWhoami": 6, + "HostWhoamiReply": 7, + "PathCheck": 8, + "PathCheckReply": 9, +} + +func (x NebulaMeta_MessageType) String() string { + return proto.EnumName(NebulaMeta_MessageType_name, int32(x)) +} + +func (NebulaMeta_MessageType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{0, 0} +} + +type NebulaPing_MessageType int32 + +const ( + NebulaPing_Ping NebulaPing_MessageType = 0 + NebulaPing_Reply NebulaPing_MessageType = 1 +) + +var NebulaPing_MessageType_name = map[int32]string{ + 0: "Ping", + 1: "Reply", +} + +var NebulaPing_MessageType_value = map[string]int32{ + "Ping": 0, + "Reply": 1, +} + +func (x NebulaPing_MessageType) String() string { + return proto.EnumName(NebulaPing_MessageType_name, int32(x)) +} + +func (NebulaPing_MessageType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{3, 0} +} + +type NebulaMeta struct { + Type NebulaMeta_MessageType `protobuf:"varint,1,opt,name=Type,json=type,proto3,enum=nebula.NebulaMeta_MessageType" json:"Type,omitempty"` + Details *NebulaMetaDetails `protobuf:"bytes,2,opt,name=Details,json=details,proto3" json:"Details,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *NebulaMeta) Reset() { *m = NebulaMeta{} } +func (m *NebulaMeta) String() string { return proto.CompactTextString(m) } +func (*NebulaMeta) ProtoMessage() {} +func (*NebulaMeta) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{0} +} + +func (m *NebulaMeta) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_NebulaMeta.Unmarshal(m, b) +} +func (m *NebulaMeta) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_NebulaMeta.Marshal(b, m, deterministic) +} +func (m *NebulaMeta) XXX_Merge(src proto.Message) { + xxx_messageInfo_NebulaMeta.Merge(m, src) +} +func (m *NebulaMeta) XXX_Size() int { + return xxx_messageInfo_NebulaMeta.Size(m) +} +func (m *NebulaMeta) XXX_DiscardUnknown() { + xxx_messageInfo_NebulaMeta.DiscardUnknown(m) +} + +var xxx_messageInfo_NebulaMeta proto.InternalMessageInfo + +func (m *NebulaMeta) GetType() NebulaMeta_MessageType { + if m != nil { + return m.Type + } + return NebulaMeta_None +} + +func (m *NebulaMeta) GetDetails() *NebulaMetaDetails { + if m != nil { + return m.Details + } + return nil +} + +type NebulaMetaDetails struct { + VpnIp uint32 `protobuf:"varint,1,opt,name=VpnIp,json=vpnIp,proto3" json:"VpnIp,omitempty"` + IpAndPorts []*IpAndPort `protobuf:"bytes,2,rep,name=IpAndPorts,json=ipAndPorts,proto3" json:"IpAndPorts,omitempty"` + Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *NebulaMetaDetails) Reset() { *m = NebulaMetaDetails{} } +func (m *NebulaMetaDetails) String() string { return proto.CompactTextString(m) } +func (*NebulaMetaDetails) ProtoMessage() {} +func (*NebulaMetaDetails) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{1} +} + +func (m *NebulaMetaDetails) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_NebulaMetaDetails.Unmarshal(m, b) +} +func (m *NebulaMetaDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_NebulaMetaDetails.Marshal(b, m, deterministic) +} +func (m *NebulaMetaDetails) XXX_Merge(src proto.Message) { + xxx_messageInfo_NebulaMetaDetails.Merge(m, src) +} +func (m *NebulaMetaDetails) XXX_Size() int { + return xxx_messageInfo_NebulaMetaDetails.Size(m) +} +func (m *NebulaMetaDetails) XXX_DiscardUnknown() { + xxx_messageInfo_NebulaMetaDetails.DiscardUnknown(m) +} + +var xxx_messageInfo_NebulaMetaDetails proto.InternalMessageInfo + +func (m *NebulaMetaDetails) GetVpnIp() uint32 { + if m != nil { + return m.VpnIp + } + return 0 +} + +func (m *NebulaMetaDetails) GetIpAndPorts() []*IpAndPort { + if m != nil { + return m.IpAndPorts + } + return nil +} + +func (m *NebulaMetaDetails) GetCounter() uint32 { + if m != nil { + return m.Counter + } + return 0 +} + +type IpAndPort struct { + Ip uint32 `protobuf:"varint,1,opt,name=Ip,json=ip,proto3" json:"Ip,omitempty"` + Port uint32 `protobuf:"varint,2,opt,name=Port,json=port,proto3" json:"Port,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *IpAndPort) Reset() { *m = IpAndPort{} } +func (m *IpAndPort) String() string { return proto.CompactTextString(m) } +func (*IpAndPort) ProtoMessage() {} +func (*IpAndPort) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{2} +} + +func (m *IpAndPort) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_IpAndPort.Unmarshal(m, b) +} +func (m *IpAndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_IpAndPort.Marshal(b, m, deterministic) +} +func (m *IpAndPort) XXX_Merge(src proto.Message) { + xxx_messageInfo_IpAndPort.Merge(m, src) +} +func (m *IpAndPort) XXX_Size() int { + return xxx_messageInfo_IpAndPort.Size(m) +} +func (m *IpAndPort) XXX_DiscardUnknown() { + xxx_messageInfo_IpAndPort.DiscardUnknown(m) +} + +var xxx_messageInfo_IpAndPort proto.InternalMessageInfo + +func (m *IpAndPort) GetIp() uint32 { + if m != nil { + return m.Ip + } + return 0 +} + +func (m *IpAndPort) GetPort() uint32 { + if m != nil { + return m.Port + } + return 0 +} + +type NebulaPing struct { + Type NebulaPing_MessageType `protobuf:"varint,1,opt,name=Type,json=type,proto3,enum=nebula.NebulaPing_MessageType" json:"Type,omitempty"` + Time uint64 `protobuf:"varint,2,opt,name=Time,json=time,proto3" json:"Time,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *NebulaPing) Reset() { *m = NebulaPing{} } +func (m *NebulaPing) String() string { return proto.CompactTextString(m) } +func (*NebulaPing) ProtoMessage() {} +func (*NebulaPing) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{3} +} + +func (m *NebulaPing) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_NebulaPing.Unmarshal(m, b) +} +func (m *NebulaPing) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_NebulaPing.Marshal(b, m, deterministic) +} +func (m *NebulaPing) XXX_Merge(src proto.Message) { + xxx_messageInfo_NebulaPing.Merge(m, src) +} +func (m *NebulaPing) XXX_Size() int { + return xxx_messageInfo_NebulaPing.Size(m) +} +func (m *NebulaPing) XXX_DiscardUnknown() { + xxx_messageInfo_NebulaPing.DiscardUnknown(m) +} + +var xxx_messageInfo_NebulaPing proto.InternalMessageInfo + +func (m *NebulaPing) GetType() NebulaPing_MessageType { + if m != nil { + return m.Type + } + return NebulaPing_Ping +} + +func (m *NebulaPing) GetTime() uint64 { + if m != nil { + return m.Time + } + return 0 +} + +type NebulaHandshake struct { + Details *NebulaHandshakeDetails `protobuf:"bytes,1,opt,name=Details,json=details,proto3" json:"Details,omitempty"` + Hmac []byte `protobuf:"bytes,2,opt,name=Hmac,json=hmac,proto3" json:"Hmac,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} } +func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) } +func (*NebulaHandshake) ProtoMessage() {} +func (*NebulaHandshake) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{4} +} + +func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_NebulaHandshake.Unmarshal(m, b) +} +func (m *NebulaHandshake) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_NebulaHandshake.Marshal(b, m, deterministic) +} +func (m *NebulaHandshake) XXX_Merge(src proto.Message) { + xxx_messageInfo_NebulaHandshake.Merge(m, src) +} +func (m *NebulaHandshake) XXX_Size() int { + return xxx_messageInfo_NebulaHandshake.Size(m) +} +func (m *NebulaHandshake) XXX_DiscardUnknown() { + xxx_messageInfo_NebulaHandshake.DiscardUnknown(m) +} + +var xxx_messageInfo_NebulaHandshake proto.InternalMessageInfo + +func (m *NebulaHandshake) GetDetails() *NebulaHandshakeDetails { + if m != nil { + return m.Details + } + return nil +} + +func (m *NebulaHandshake) GetHmac() []byte { + if m != nil { + return m.Hmac + } + return nil +} + +type NebulaHandshakeDetails struct { + Cert []byte `protobuf:"bytes,1,opt,name=Cert,json=cert,proto3" json:"Cert,omitempty"` + InitiatorIndex uint32 `protobuf:"varint,2,opt,name=InitiatorIndex,json=initiatorIndex,proto3" json:"InitiatorIndex,omitempty"` + ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,json=responderIndex,proto3" json:"ResponderIndex,omitempty"` + Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,json=cookie,proto3" json:"Cookie,omitempty"` + Time uint64 `protobuf:"varint,5,opt,name=Time,json=time,proto3" json:"Time,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} } +func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) } +func (*NebulaHandshakeDetails) ProtoMessage() {} +func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{5} +} + +func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_NebulaHandshakeDetails.Unmarshal(m, b) +} +func (m *NebulaHandshakeDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_NebulaHandshakeDetails.Marshal(b, m, deterministic) +} +func (m *NebulaHandshakeDetails) XXX_Merge(src proto.Message) { + xxx_messageInfo_NebulaHandshakeDetails.Merge(m, src) +} +func (m *NebulaHandshakeDetails) XXX_Size() int { + return xxx_messageInfo_NebulaHandshakeDetails.Size(m) +} +func (m *NebulaHandshakeDetails) XXX_DiscardUnknown() { + xxx_messageInfo_NebulaHandshakeDetails.DiscardUnknown(m) +} + +var xxx_messageInfo_NebulaHandshakeDetails proto.InternalMessageInfo + +func (m *NebulaHandshakeDetails) GetCert() []byte { + if m != nil { + return m.Cert + } + return nil +} + +func (m *NebulaHandshakeDetails) GetInitiatorIndex() uint32 { + if m != nil { + return m.InitiatorIndex + } + return 0 +} + +func (m *NebulaHandshakeDetails) GetResponderIndex() uint32 { + if m != nil { + return m.ResponderIndex + } + return 0 +} + +func (m *NebulaHandshakeDetails) GetCookie() uint64 { + if m != nil { + return m.Cookie + } + return 0 +} + +func (m *NebulaHandshakeDetails) GetTime() uint64 { + if m != nil { + return m.Time + } + return 0 +} + +func init() { + proto.RegisterEnum("nebula.NebulaMeta_MessageType", NebulaMeta_MessageType_name, NebulaMeta_MessageType_value) + proto.RegisterEnum("nebula.NebulaPing_MessageType", NebulaPing_MessageType_name, NebulaPing_MessageType_value) + proto.RegisterType((*NebulaMeta)(nil), "nebula.NebulaMeta") + proto.RegisterType((*NebulaMetaDetails)(nil), "nebula.NebulaMetaDetails") + proto.RegisterType((*IpAndPort)(nil), "nebula.IpAndPort") + proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") + proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake") + proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails") +} + +func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } + +var fileDescriptor_2d65afa7693df5ef = []byte{ + // 491 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x53, 0x41, 0x6f, 0xda, 0x4c, + 0x10, 0x8d, 0x61, 0x81, 0x30, 0x80, 0xe3, 0xcc, 0xf7, 0x15, 0x91, 0x1e, 0xaa, 0xc8, 0x87, 0x8a, + 0x13, 0x55, 0xc9, 0xa5, 0xd7, 0x8a, 0x1e, 0xe0, 0x00, 0xa2, 0x56, 0xda, 0x1e, 0xab, 0x8d, 0x3d, + 0x8d, 0x57, 0xe0, 0xdd, 0x95, 0xbd, 0xa0, 0xf0, 0x8f, 0xfa, 0x63, 0x7a, 0xec, 0x0f, 0xaa, 0x76, + 0x0d, 0xa6, 0x84, 0xa8, 0xb7, 0x7d, 0xf3, 0xde, 0xcc, 0x8e, 0xdf, 0x5b, 0x43, 0x57, 0xd2, 0xc3, + 0x66, 0xcd, 0x47, 0x3a, 0x57, 0x46, 0x61, 0xb3, 0x44, 0xe1, 0xaf, 0x1a, 0xc0, 0xc2, 0x1d, 0xe7, + 0x64, 0x38, 0x8e, 0x81, 0xdd, 0xef, 0x34, 0x0d, 0xbc, 0x5b, 0x6f, 0xe8, 0x8f, 0xdf, 0x8c, 0xf6, + 0x3d, 0x47, 0xc5, 0x68, 0x4e, 0x45, 0xc1, 0x1f, 0xc9, 0xaa, 0x22, 0x66, 0x76, 0x9a, 0xf0, 0x0e, + 0x5a, 0x9f, 0xc8, 0x70, 0xb1, 0x2e, 0x06, 0xb5, 0x5b, 0x6f, 0xd8, 0x19, 0xdf, 0x9c, 0xb7, 0xed, + 0x05, 0x51, 0x2b, 0x29, 0x0f, 0xe1, 0x6f, 0x0f, 0x3a, 0x7f, 0x8d, 0xc2, 0x4b, 0x60, 0x0b, 0x25, + 0x29, 0xb8, 0xc0, 0x1e, 0xb4, 0xa7, 0xaa, 0x30, 0x9f, 0x37, 0x94, 0xef, 0x02, 0x0f, 0x11, 0xfc, + 0x0a, 0x46, 0xa4, 0xd7, 0xbb, 0xa0, 0x86, 0xaf, 0xa1, 0x6f, 0x6b, 0x5f, 0x74, 0xc2, 0x0d, 0x2d, + 0x94, 0x11, 0x3f, 0x44, 0xcc, 0x8d, 0x50, 0x32, 0xa8, 0xe3, 0x0d, 0xbc, 0xb2, 0xdc, 0x5c, 0x6d, + 0x29, 0x39, 0xa1, 0xd8, 0x81, 0x5a, 0x6e, 0x64, 0x9c, 0x9e, 0x50, 0x0d, 0xf4, 0x01, 0x2c, 0xf5, + 0x2d, 0x55, 0x3c, 0x13, 0x41, 0x13, 0xff, 0x83, 0xab, 0x23, 0x2e, 0xaf, 0x6d, 0xd9, 0xcd, 0x96, + 0xdc, 0xa4, 0x93, 0x94, 0xe2, 0x55, 0x70, 0x69, 0x37, 0xab, 0x60, 0x29, 0x69, 0x87, 0x5b, 0xb8, + 0x3e, 0xfb, 0x68, 0xfc, 0x1f, 0x1a, 0x5f, 0xb5, 0x9c, 0x69, 0xe7, 0x6a, 0x2f, 0x6a, 0x6c, 0x2d, + 0xc0, 0xf7, 0x00, 0x33, 0xfd, 0x51, 0x26, 0x4b, 0x95, 0x1b, 0xeb, 0x5c, 0x7d, 0xd8, 0x19, 0x5f, + 0x1f, 0x9c, 0xab, 0x98, 0x08, 0x44, 0x25, 0xc2, 0x01, 0xb4, 0x62, 0xb5, 0x91, 0x86, 0xf2, 0x41, + 0xdd, 0x8d, 0x3a, 0xc0, 0xf0, 0x1d, 0xb4, 0xab, 0x16, 0xf4, 0xa1, 0x56, 0x5d, 0x56, 0x13, 0x1a, + 0x11, 0x98, 0xad, 0xbb, 0x74, 0x7a, 0x11, 0xd3, 0x2a, 0x37, 0xe1, 0xd3, 0x21, 0xf6, 0xa5, 0x90, + 0x8f, 0xff, 0x8e, 0xdd, 0x2a, 0x5e, 0x88, 0x1d, 0x81, 0xdd, 0x8b, 0x8c, 0xdc, 0x54, 0x16, 0x31, + 0x23, 0x32, 0x0a, 0xc3, 0xb3, 0x50, 0x6d, 0x73, 0x70, 0x81, 0x6d, 0x68, 0x94, 0x16, 0x79, 0xe1, + 0x77, 0xb8, 0x2a, 0xe7, 0x4e, 0xb9, 0x4c, 0x8a, 0x94, 0xaf, 0x08, 0x3f, 0x1c, 0x5f, 0x90, 0xe7, + 0x5e, 0xd0, 0xb3, 0x0d, 0x2a, 0xe5, 0xf3, 0x67, 0x64, 0x97, 0x98, 0x66, 0x3c, 0x76, 0x4b, 0x74, + 0x23, 0x96, 0x66, 0x3c, 0x0e, 0x7f, 0x7a, 0xd0, 0x7f, 0xb9, 0xcf, 0xca, 0x27, 0x94, 0x1b, 0x77, + 0x4b, 0x37, 0x62, 0x31, 0xe5, 0x06, 0xdf, 0x82, 0x3f, 0x93, 0xc2, 0x08, 0x6e, 0x54, 0x3e, 0x93, + 0x09, 0x3d, 0xed, 0x7d, 0xf2, 0xc5, 0x49, 0xd5, 0xea, 0x22, 0x2a, 0xb4, 0x92, 0x09, 0xed, 0x75, + 0x65, 0x06, 0x7e, 0x7e, 0x52, 0xc5, 0x3e, 0x34, 0x27, 0x4a, 0xad, 0x04, 0x0d, 0x98, 0x73, 0xa6, + 0x19, 0x3b, 0x54, 0xf9, 0xd5, 0x38, 0xfa, 0xf5, 0xd0, 0x74, 0x3f, 0xe3, 0xdd, 0x9f, 0x00, 0x00, + 0x00, 0xff, 0xff, 0x65, 0xc6, 0x25, 0x44, 0x9c, 0x03, 0x00, 0x00, +} diff --git a/nebula.proto b/nebula.proto new file mode 100644 index 0000000..75994c0 --- /dev/null +++ b/nebula.proto @@ -0,0 +1,59 @@ +syntax = "proto3"; +package nebula; + +message NebulaMeta { + enum MessageType { + None = 0; + HostQuery = 1; + HostQueryReply = 2; + HostUpdateNotification = 3; + HostMovedNotification = 4; + HostPunchNotification = 5; + HostWhoami = 6; + HostWhoamiReply = 7; + PathCheck = 8; + PathCheckReply = 9; + + } + + MessageType Type = 1; + NebulaMetaDetails Details = 2; +} + + +message NebulaMetaDetails { + + uint32 VpnIp = 1; + repeated IpAndPort IpAndPorts = 2; + uint32 counter = 3; +} + +message IpAndPort { + uint32 Ip = 1; + uint32 Port = 2; +} + + +message NebulaPing { + enum MessageType { + Ping = 0; + Reply = 1; + } + + MessageType Type = 1; + uint64 Time = 2; +} + +message NebulaHandshake { + NebulaHandshakeDetails Details = 1; + bytes Hmac = 2; +} + +message NebulaHandshakeDetails { + bytes Cert = 1; + uint32 InitiatorIndex = 2; + uint32 ResponderIndex = 3; + uint64 Cookie = 4; + uint64 Time = 5; +} + diff --git a/noise.go b/noise.go new file mode 100644 index 0000000..29e80ef --- /dev/null +++ b/noise.go @@ -0,0 +1,60 @@ +package nebula + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + + "github.com/flynn/noise" +) + +type endiannes interface { + PutUint64(b []byte, v uint64) +} + +var noiseEndiannes endiannes = binary.BigEndian + +type NebulaCipherState struct { + c noise.Cipher + //k [32]byte + //n uint64 +} + +func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState { + return &NebulaCipherState{c: s.Cipher()} + +} + +func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { + if s != nil { + // TODO: Is this okay now that we have made messageCounter atomic? + // Alternative may be to split the counter space into ranges + //if n <= s.n { + // return nil, errors.New("CRITICAL: a duplicate counter value was used") + //} + //s.n = n + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + noiseEndiannes.PutUint64(nb[4:], n) + out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad) + //l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext)) + return out, nil + } else { + return nil, errors.New("no cipher state available to encrypt") + } +} + +func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { + if s != nil { + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + noiseEndiannes.PutUint64(nb[4:], n) + return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad) + } else { + return []byte{}, nil + } +} diff --git a/outside.go b/outside.go new file mode 100644 index 0000000..968b8d2 --- /dev/null +++ b/outside.go @@ -0,0 +1,410 @@ +package nebula + +import ( + "encoding/binary" + + "github.com/flynn/noise" + "github.com/golang/protobuf/proto" + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + // "github.com/google/gopacket" + // "github.com/google/gopacket/layers" + // "encoding/binary" + "errors" + "fmt" + "time" + + "golang.org/x/net/ipv4" +) + +const ( + minFwPacketLen = 4 +) + +func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, nb []byte) { + err := header.Parse(packet) + if err != nil { + // TODO: best if we return this and let caller log + // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? + // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors + if len(packet) > 1 { + l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err) + } + return + } + + //l.Error("in packet ", header, packet[HeaderLen:]) + + // verify if we've seen this index before, otherwise respond to the handshake initiation + hostinfo, err := f.hostMap.QueryIndex(header.RemoteIndex) + + var ci *ConnectionState + if err == nil { + ci = hostinfo.ConnectionState + } + + switch header.Type { + case message: + if !f.handleEncrypted(ci, addr, header) { + return + } + + f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb) + + // Fallthrough to the bottom to record incoming traffic + + case lightHouse: + if !f.handleEncrypted(ci, addr, header) { + return + } + + d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb) + if err != nil { + l.WithError(err).WithField("udpAddr", addr).WithField("vpnIp", IntIp(hostinfo.hostId)). + WithField("packet", packet). + Error("Failed to decrypt lighthouse packet") + + //TODO: maybe after build 64 is out? 06/14/2018 - NB + //f.sendRecvError(net.Addr(addr), header.RemoteIndex) + return + } + + f.lightHouse.HandleRequest(addr, hostinfo.hostId, d, hostinfo.GetCert(), f) + + // Fallthrough to the bottom to record incoming traffic + + case test: + if !f.handleEncrypted(ci, addr, header) { + return + } + + d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb) + if err != nil { + l.WithError(err).WithField("udpAddr", addr).WithField("vpnIp", IntIp(hostinfo.hostId)). + WithField("packet", packet). + Error("Failed to decrypt test packet") + + //TODO: maybe after build 64 is out? 06/14/2018 - NB + //f.sendRecvError(net.Addr(addr), header.RemoteIndex) + return + } + + if header.Subtype == testRequest { + // This testRequest might be from TryPromoteBest, so we should roam + // to the new IP address before responding + f.handleHostRoaming(hostinfo, addr) + f.send(test, testReply, ci, hostinfo, hostinfo.remote, d, nb, out) + } + + // Fallthrough to the bottom to record incoming traffic + + // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they + // are unauthenticated + + case handshake: + HandleIncomingHandshake(f, addr, packet, header, hostinfo) + return + + case recvError: + // TODO: Remove this with recv_error deprecation + f.handleRecvError(addr, header) + return + + case closeTunnel: + if !f.handleEncrypted(ci, addr, header) { + return + } + + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). + Info("Close tunnel received, tearing down.") + + f.closeTunnel(hostinfo) + return + + default: + l.Debugf("Unexpected packet received from %s", addr) + return + } + + f.handleHostRoaming(hostinfo, addr) + + f.connectionManager.In(hostinfo.hostId) +} + +func (f *Interface) closeTunnel(hostInfo *HostInfo) { + //TODO: this would be better as a single function in ConnectionManager that handled locks appropriately + f.connectionManager.ClearIP(hostInfo.hostId) + f.connectionManager.ClearPendingDeletion(hostInfo.hostId) + f.lightHouse.DeleteVpnIP(hostInfo.hostId) + f.hostMap.DeleteVpnIP(hostInfo.hostId) + f.hostMap.DeleteIndex(hostInfo.localIndexId) +} + +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { + if hostDidRoam(hostinfo.remote, addr) { + if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSupressSeconds*time.Second { + if l.Level >= logrus.DebugLevel { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + Debugf("Supressing roam back to previous remote for %d seconds", RoamingSupressSeconds) + } + return + } + + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + Info("Host roamed to new udp ip/port.") + hostinfo.lastRoam = time.Now() + remoteCopy := *hostinfo.remote + hostinfo.lastRoamRemote = &remoteCopy + hostinfo.SetRemote(*addr) + if f.lightHouse.amLighthouse { + f.lightHouse.AddRemote(hostinfo.hostId, addr, false) + } + } + +} + +func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool { + // If connectionstate exists and the replay protector allows, process packet + // Else, send recv errors for 300 seconds after a restart to allow fast reconnection. + if ci == nil || !ci.window.Check(header.MessageCounter) { + f.sendRecvError(addr, header.RemoteIndex) + return false + } + + return true +} + +// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers +func newPacket(data []byte, incoming bool, fp *FirewallPacket) error { + // Do we at least have an ipv4 header worth of data? + if len(data) < ipv4.HeaderLen { + return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen) + } + + // Is it an ipv4 packet? + if int((data[0]>>4)&0x0f) != 4 { + return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f)) + } + + // Adjust our start position based on the advertised ip header length + ihl := int(data[0]&0x0f) << 2 + + // Well formed ip header length? + if ihl < ipv4.HeaderLen { + return fmt.Errorf("packet had an invalid header length: %v", ihl) + } + + // Check if this is the second or further fragment of a fragmented packet. + flagsfrags := binary.BigEndian.Uint16(data[6:8]) + fp.Fragment = (flagsfrags & 0x1FFF) != 0 + + // Firewall handles protocol checks + fp.Protocol = data[9] + + // Accounting for a variable header length, do we have enough data for our src/dst tuples? + minLen := ihl + if !fp.Fragment && fp.Protocol != fwProtoICMP { + minLen += minFwPacketLen + } + if len(data) < minLen { + return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl) + } + + // Firewall packets are locally oriented + if incoming { + fp.RemoteIP = binary.BigEndian.Uint32(data[12:16]) + fp.LocalIP = binary.BigEndian.Uint32(data[16:20]) + if fp.Fragment || fp.Protocol == fwProtoICMP { + fp.RemotePort = 0 + fp.LocalPort = 0 + } else { + fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) + } + } else { + fp.LocalIP = binary.BigEndian.Uint32(data[12:16]) + fp.RemoteIP = binary.BigEndian.Uint32(data[16:20]) + if fp.Fragment || fp.Protocol == fwProtoICMP { + fp.RemotePort = 0 + fp.LocalPort = 0 + } else { + fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) + } + } + + return nil +} + +func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, header *Header, nb []byte) ([]byte, error) { + var err error + out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], mc, nb) + if err != nil { + return nil, err + } + + if !hostinfo.ConnectionState.window.Update(mc) { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("header", header). + Debugln("dropping out of window packet") + return nil, errors.New("out of window packet") + } + + return out, nil +} + +func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) { + var err error + + // TODO: This breaks subnet routing and needs to also check range of ip subnet + /* + if len(res) > 16 && binary.BigEndian.Uint32(res[12:16]) != ip2int(ci.peerCert.Details.Ips[0].IP) { + l.Debugf("Host %s tried to spoof packet as %s.", ci.peerCert.Details.Ips[0].IP, IntIp(binary.BigEndian.Uint32(res[12:16]))) + } + */ + + out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb) + if err != nil { + l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).Error("Failed to decrypt packet") + //TODO: maybe after build 64 is out? 06/14/2018 - NB + //f.sendRecvError(hostinfo.remote, header.RemoteIndex) + return + } + + err = newPacket(out, true, fwPacket) + if err != nil { + l.WithError(err).WithField("packet", out).WithField("hostInfo", IntIp(hostinfo.hostId)). + Warnf("Error while validating inbound packet") + return + } + + if !hostinfo.ConnectionState.window.Update(messageCounter) { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket). + Debugln("dropping out of window packet") + return + } + + if f.firewall.Drop(out, *fwPacket, true, hostinfo.ConnectionState.peerCert, trustedCAs) { + l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket). + Debugln("dropping inbound packet") + return + } + + f.connectionManager.In(hostinfo.hostId) + err = f.inside.WriteRaw(out) + if err != nil { + l.WithError(err).Error("Failed to write to tun") + } +} + +func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) { + f.metricTxRecvError.Inc(1) + + //TODO: this should be a signed message so we can trust that we should drop the index + b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0) + f.outside.WriteTo(b, endpoint) + if l.Level >= logrus.DebugLevel { + l.WithField("index", index). + WithField("udpAddr", endpoint). + Debug("Recv error sent") + } +} + +func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { + f.metricRxRecvError.Inc(1) + + // This flag is to stop caring about recv_error from old versions + // This should go away when the old version is gone from prod + if l.Level >= logrus.DebugLevel { + l.WithField("index", h.RemoteIndex). + WithField("udpAddr", addr). + Debug("Recv error received") + } + + hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex) + if err != nil { + l.Debugln(err, ": ", h.RemoteIndex) + return + } + + if !hostinfo.RecvErrorExceeded() { + return + } + if hostinfo.remote != nil && hostinfo.remote.String() != addr.String() { + l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) + return + } + + id := hostinfo.localIndexId + host := hostinfo.hostId + // We delete this host from the main hostmap + f.hostMap.DeleteIndex(id) + f.hostMap.DeleteVpnIP(host) + // We also delete it from pending to allow for + // fast reconnect. We must null the connectionstate + // or a counter reuse may happen + hostinfo.ConnectionState = nil + f.handshakeManager.DeleteIndex(id) + f.handshakeManager.DeleteVpnIP(host) +} + +/* +func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) { + if ci.eKey != nil { + //TODO: log error? + return + } + + msg, err := proto.Marshal(meta) + if err != nil { + l.Debugln("failed to encode header") + } + + c := ci.messageCounter + b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c) + ci.messageCounter++ + + msg := ci.eKey.EncryptDanger(b, nil, msg, c) + //msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c) + f.outside.WriteTo(msg, endpoint) +} +*/ + +func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*cert.NebulaCertificate, error) { + pk := h.PeerStatic() + + if pk == nil { + return nil, errors.New("no peer static key was present") + } + + if rawCertBytes == nil { + return nil, errors.New("provided payload was empty") + } + + r := &cert.RawNebulaCertificate{} + err := proto.Unmarshal(rawCertBytes, r) + if err != nil { + return nil, fmt.Errorf("error unmarshaling cert: %s", err) + } + + // If the Details are nil, just exit to avoid crashing + if r.Details == nil { + return nil, fmt.Errorf("certificate did not contain any details") + } + + r.Details.PublicKey = pk + recombined, err := proto.Marshal(r) + if err != nil { + return nil, fmt.Errorf("error while recombining certificate: %s", err) + } + + c, _ := cert.UnmarshalNebulaCertificate(recombined) + isValid, err := c.Verify(time.Now(), trustedCAs) + if err != nil { + return c, fmt.Errorf("certificate validation failed: %s", err) + } else if !isValid { + // This case should never happen but here's to defensive programming! + return c, errors.New("certificate validation failed but did not return an error") + } + + return c, nil +} diff --git a/outside_test.go b/outside_test.go new file mode 100644 index 0000000..c4dee27 --- /dev/null +++ b/outside_test.go @@ -0,0 +1,80 @@ +package nebula + +import ( + "github.com/stretchr/testify/assert" + "golang.org/x/net/ipv4" + "net" + "testing" +) + +func Test_newPacket(t *testing.T) { + p := &FirewallPacket{} + + // length fail + err := newPacket([]byte{0, 1}, true, p) + assert.EqualError(t, err, "packet is less than 20 bytes") + + // length fail with ip options + h := ipv4.Header{ + Version: 1, + Len: 100, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Options: []byte{0, 1, 0, 2}, + } + + b, _ := h.Marshal() + err = newPacket(b, true, p) + + assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24") + + // not an ipv4 packet + err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) + assert.EqualError(t, err, "packet is not ipv4, type: 0") + + // invalid ihl + err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) + assert.EqualError(t, err, "packet had an invalid header length: 8") + + // account for variable ip header length - incoming + h = ipv4.Header{ + Version: 1, + Len: 100, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Options: []byte{0, 1, 0, 2}, + Protocol: fwProtoTCP, + } + + b, _ = h.Marshal() + b = append(b, []byte{0, 3, 0, 4}...) + err = newPacket(b, true, p) + + assert.Nil(t, err) + assert.Equal(t, p.Protocol, uint8(fwProtoTCP)) + assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 2))) + assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 1))) + assert.Equal(t, p.RemotePort, uint16(3)) + assert.Equal(t, p.LocalPort, uint16(4)) + + // account for variable ip header length - outgoing + h = ipv4.Header{ + Version: 1, + Protocol: 2, + Len: 100, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Options: []byte{0, 1, 0, 2}, + } + + b, _ = h.Marshal() + b = append(b, []byte{0, 5, 0, 6}...) + err = newPacket(b, false, p) + + assert.Nil(t, err) + assert.Equal(t, p.Protocol, uint8(2)) + assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 1))) + assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 2))) + assert.Equal(t, p.RemotePort, uint16(6)) + assert.Equal(t, p.LocalPort, uint16(5)) +} diff --git a/ssh.go b/ssh.go new file mode 100644 index 0000000..81846b8 --- /dev/null +++ b/ssh.go @@ -0,0 +1,727 @@ +package nebula + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "github.com/sirupsen/logrus" + "io/ioutil" + "net" + "os" + "reflect" + "runtime/pprof" + "github.com/slackhq/nebula/sshd" + "strings" + "syscall" +) + +type sshListHostMapFlags struct { + Json bool + Pretty bool +} + +type sshPrintCertFlags struct { + Json bool + Pretty bool +} + +type sshPrintTunnelFlags struct { + Pretty bool +} + +type sshChangeRemoteFlags struct { + Address string +} + +type sshCloseTunnelFlags struct { + LocalOnly bool +} + +type sshCreateTunnelFlags struct { + Address string +} + +func wireSSHReload(ssh *sshd.SSHServer, c *Config) { + c.RegisterReloadCallback(func(c *Config) { + if c.GetBool("sshd.enabled", false) { + err := configSSH(ssh, c) + if err != nil { + l.WithError(err).Error("Failed to reconfigure the sshd") + ssh.Stop() + } + } else { + ssh.Stop() + } + }) +} + +func configSSH(ssh *sshd.SSHServer, c *Config) error { + //TODO conntrack list + //TODO print firewall rules or hash? + + listen := c.GetString("sshd.listen", "") + if listen == "" { + return fmt.Errorf("sshd.listen must be provided") + } + + port := strings.Split(listen, ":") + if len(port) < 2 { + return fmt.Errorf("sshd.listen does not have a port") + } else if port[1] == "22" { + return fmt.Errorf("sshd.listen can not use port 22") + } + + //TODO: no good way to reload this right now + hostKeyFile := c.GetString("sshd.host_key", "") + if hostKeyFile == "" { + return fmt.Errorf("sshd.host_key must be provided") + } + + hostKeyBytes, err := ioutil.ReadFile(hostKeyFile) + if err != nil { + return fmt.Errorf("error while loading sshd.host_key file: %s", err) + } + + err = ssh.SetHostKey(hostKeyBytes) + if err != nil { + return fmt.Errorf("error while adding sshd.host_key: %s", err) + } + + rawKeys := c.Get("sshd.authorized_users") + keys, ok := rawKeys.([]interface{}) + if ok { + for _, rk := range keys { + kDef, ok := rk.(map[interface{}]interface{}) + if !ok { + l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring") + continue + } + + user, ok := kDef["user"].(string) + if !ok { + l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field") + continue + } + + k := kDef["keys"] + switch v := k.(type) { + case string: + err := ssh.AddAuthorizedKey(user, v) + if err != nil { + l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key") + continue + } + + case []interface{}: + for _, subK := range v { + sk, ok := subK.(string) + if !ok { + l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key") + continue + } + + err := ssh.AddAuthorizedKey(user, sk) + if err != nil { + l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key") + continue + } + } + + default: + l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood") + } + } + } else { + l.Info("no ssh users to authorize") + } + + if c.GetBool("sshd.enabled", false) { + ssh.Stop() + go ssh.Run(listen) + } else { + ssh.Stop() + } + + return nil +} + +func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) { + ssh.RegisterCommand(&sshd.Command{ + Name: "list-hostmap", + ShortDescription: "List all known previously connected hosts", + Flags: func() (*flag.FlagSet, interface{}) { + fl := flag.NewFlagSet("", flag.ContinueOnError) + s := sshListHostMapFlags{} + fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") + fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") + return fl, &s + }, + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshListHostMap(hostMap, fs, w) + }, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "list-pending-hostmap", + ShortDescription: "List all handshaking hosts", + Flags: func() (*flag.FlagSet, interface{}) { + fl := flag.NewFlagSet("", flag.ContinueOnError) + s := sshListHostMapFlags{} + fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") + fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") + return fl, &s + }, + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshListHostMap(pendingHostMap, fs, w) + }, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "list-lighthouse-addrmap", + ShortDescription: "List all lighthouse map entries", + Flags: func() (*flag.FlagSet, interface{}) { + fl := flag.NewFlagSet("", flag.ContinueOnError) + s := sshListHostMapFlags{} + fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") + fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") + return fl, &s + }, + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshListLighthouseMap(lightHouse, fs, w) + }, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "reload", + ShortDescription: "Reloads configuration from disk, same as sending HUP to the process", + Callback: sshReload, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "start-cpu-profile", + ShortDescription: "Starts a cpu profile and write output to the provided file", + Callback: sshStartCpuProfile, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "stop-cpu-profile", + ShortDescription: "Stops a cpu profile and writes output to the previously provided file", + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + pprof.StopCPUProfile() + return w.WriteLine("If a CPU profile was running it is now stopped") + }, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "save-heap-profile", + ShortDescription: "Saves a heap profile to the provided path", + Callback: sshGetHeapProfile, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "log-level", + ShortDescription: "Gets or sets the current log level", + Callback: sshLogLevel, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "log-format", + ShortDescription: "Gets or sets the current log format", + Callback: sshLogFormat, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "version", + ShortDescription: "Prints the currently running version of nebula", + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshVersion(ifce, fs, a, w) + }, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "print-cert", + ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip", + Flags: func() (*flag.FlagSet, interface{}) { + fl := flag.NewFlagSet("", flag.ContinueOnError) + s := sshPrintCertFlags{} + fl.BoolVar(&s.Json, "json", false, "outputs as json") + fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") + return fl, &s + }, + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshPrintCert(ifce, fs, a, w) + }, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "print-tunnel", + ShortDescription: "Prints json details about a tunnel for the provided vpn ip", + Flags: func() (*flag.FlagSet, interface{}) { + fl := flag.NewFlagSet("", flag.ContinueOnError) + s := sshPrintTunnelFlags{} + fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json") + return fl, &s + }, + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshPrintTunnel(ifce, fs, a, w) + }, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "change-remote", + ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip", + Flags: func() (*flag.FlagSet, interface{}) { + fl := flag.NewFlagSet("", flag.ContinueOnError) + s := sshChangeRemoteFlags{} + fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port") + return fl, &s + }, + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshChangeRemote(ifce, fs, a, w) + }, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "close-tunnel", + ShortDescription: "Closes a tunnel for the provided vpn ip", + Flags: func() (*flag.FlagSet, interface{}) { + fl := flag.NewFlagSet("", flag.ContinueOnError) + s := sshCloseTunnelFlags{} + fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down") + return fl, &s + }, + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshCloseTunnel(ifce, fs, a, w) + }, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "create-tunnel", + ShortDescription: "Creates a tunnel for the provided vpn ip and address", + Help: "The lighthouses will be queried for real addresses but you can provide one as well.", + Flags: func() (*flag.FlagSet, interface{}) { + fl := flag.NewFlagSet("", flag.ContinueOnError) + s := sshCreateTunnelFlags{} + fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ") + return fl, &s + }, + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshCreateTunnel(ifce, fs, a, w) + }, + }) + + ssh.RegisterCommand(&sshd.Command{ + Name: "query-lighthouse", + ShortDescription: "Query the lighthouses for the provided vpn ip", + Help: "This command is asynchronous. Only currently known udp ips will be printed.", + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshQueryLighthouse(ifce, fs, a, w) + }, + }) +} + +func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error { + fs, ok := a.(*sshListHostMapFlags) + if !ok { + //TODO: error + return nil + } + + hostMap.RLock() + defer hostMap.RUnlock() + + if fs.Json || fs.Pretty { + js := json.NewEncoder(w.GetWriter()) + if fs.Pretty { + js.SetIndent("", " ") + } + + d := make([]m, len(hostMap.Hosts)) + x := 0 + var h m + for _, v := range hostMap.Hosts { + h = m{ + "vpnIp": int2ip(v.hostId), + "localIndex": v.localIndexId, + "remoteIndex": v.remoteIndexId, + "remoteAddrs": v.RemoteUDPAddrs(), + "cachedPackets": len(v.packetStore), + "cert": v.GetCert(), + } + + if v.ConnectionState != nil { + h["messageCounter"] = v.ConnectionState.messageCounter + } + + d[x] = h + x++ + } + + err := js.Encode(d) + if err != nil { + //TODO + return nil + } + } else { + for i, v := range hostMap.Hosts { + err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(i), v.RemoteUDPAddrs())) + if err != nil { + return err + } + } + } + + return nil +} + +func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error { + fs, ok := a.(*sshListHostMapFlags) + if !ok { + //TODO: error + return nil + } + + lightHouse.RLock() + defer lightHouse.RUnlock() + + if fs.Json || fs.Pretty { + js := json.NewEncoder(w.GetWriter()) + if fs.Pretty { + js.SetIndent("", " ") + } + + d := make([]m, len(lightHouse.addrMap)) + x := 0 + var h m + for vpnIp, v := range lightHouse.addrMap { + ips := make([]string, len(v)) + for i, ip := range v { + ips[i] = ip.String() + } + + h = m{ + "vpnIp": int2ip(vpnIp), + "addrs": ips, + } + + d[x] = h + x++ + } + + err := js.Encode(d) + if err != nil { + //TODO + return nil + } + } else { + for vpnIp, v := range lightHouse.addrMap { + ips := make([]string, len(v)) + for i, ip := range v { + ips[i] = ip.String() + } + err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(vpnIp), ips)) + if err != nil { + return err + } + } + } + + return nil +} + +func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error { + if len(a) == 0 { + err := w.WriteLine("No path to write profile provided") + return err + } + + file, err := os.Create(a[0]) + if err != nil { + err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) + return err + } + + err = pprof.StartCPUProfile(file) + if err != nil { + err = w.WriteLine(fmt.Sprintf("Unable to start cpu profile: %s", err)) + return err + } + + err = w.WriteLine(fmt.Sprintf("Started cpu profile, issue stop-cpu-profile to write the output to %s", a)) + return err +} + +func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { + return w.WriteLine(fmt.Sprintf("%s", ifce.version)) +} + +func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { + if len(a) == 0 { + return w.WriteLine("No vpn ip was provided") + } + + vpnIp := ip2int(net.ParseIP(a[0])) + if vpnIp == 0 { + return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + } + + ips, _ := ifce.lightHouse.Query(vpnIp, ifce) + return json.NewEncoder(w.GetWriter()).Encode(ips) +} + +func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { + flags, ok := fs.(*sshCloseTunnelFlags) + if !ok { + //TODO: error + return nil + } + + if len(a) == 0 { + return w.WriteLine("No vpn ip was provided") + } + + vpnIp := ip2int(net.ParseIP(a[0])) + if vpnIp == 0 { + return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + } + + hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) + if err != nil { + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + } + + if !flags.LocalOnly { + ifce.send( + closeTunnel, + 0, + hostInfo.ConnectionState, + hostInfo, + hostInfo.remote, + []byte{}, + make([]byte, 12, 12), + make([]byte, mtu), + ) + } + + ifce.closeTunnel(hostInfo) + return w.WriteLine("Closed") +} + +func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { + flags, ok := fs.(*sshCreateTunnelFlags) + if !ok { + //TODO: error + return nil + } + + if len(a) == 0 { + return w.WriteLine("No vpn ip was provided") + } + + vpnIp := ip2int(net.ParseIP(a[0])) + if vpnIp == 0 { + return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + } + + hostInfo, _ := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) + if hostInfo != nil { + return w.WriteLine(fmt.Sprintf("Tunnel already exists")) + } + + hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIP(uint32(vpnIp)) + if hostInfo != nil { + return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) + } + + var addr *udpAddr + if flags.Address != "" { + addr = NewUDPAddrFromString(flags.Address) + if addr == nil { + return w.WriteLine("Address could not be parsed") + } + } + + hostInfo = ifce.handshakeManager.AddVpnIP(vpnIp) + if addr != nil { + hostInfo.SetRemote(*addr) + } + ifce.getOrHandshake(vpnIp) + + return w.WriteLine("Created") +} + +func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { + flags, ok := fs.(*sshChangeRemoteFlags) + if !ok { + //TODO: error + return nil + } + + if len(a) == 0 { + return w.WriteLine("No vpn ip was provided") + } + + if flags.Address == "" { + return w.WriteLine("No address was provided") + } + + addr := NewUDPAddrFromString(flags.Address) + if addr == nil { + return w.WriteLine("Address could not be parsed") + } + + vpnIp := ip2int(net.ParseIP(a[0])) + if vpnIp == 0 { + return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + } + + hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) + if err != nil { + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + } + + hostInfo.SetRemote(*addr) + return w.WriteLine("Changed") +} + +func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error { + if len(a) == 0 { + return w.WriteLine("No path to write profile provided") + } + + file, err := os.Create(a[0]) + if err != nil { + err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err)) + return err + } + + err = pprof.WriteHeapProfile(file) + if err != nil { + err = w.WriteLine(fmt.Sprintf("Unable to write profile: %s", err)) + return err + } + + err = w.WriteLine(fmt.Sprintf("Mem profile created at %s", a)) + return err +} + +func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error { + if len(a) == 0 { + return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) + } + + level, err := logrus.ParseLevel(a[0]) + if err != nil { + return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels)) + } + + l.SetLevel(level) + return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) +} + +func sshLogFormat(fs interface{}, a []string, w sshd.StringWriter) error { + if len(a) == 0 { + return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) + } + + logFormat := strings.ToLower(a[0]) + switch logFormat { + case "text": + l.Formatter = &logrus.TextFormatter{} + case "json": + l.Formatter = &logrus.JSONFormatter{} + default: + return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) + } + + return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) +} + +func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { + args, ok := fs.(*sshPrintCertFlags) + if !ok { + //TODO: error + return nil + } + + cert := ifce.certState.certificate + if len(a) > 0 { + vpnIp := ip2int(net.ParseIP(a[0])) + if vpnIp == 0 { + return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + } + + hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) + if err != nil { + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + } + + cert = hostInfo.GetCert() + } + + if args.Json || args.Pretty { + b, err := cert.MarshalJSON() + if err != nil { + //TODO: handle it + return nil + } + + if args.Pretty { + buf := new(bytes.Buffer) + err := json.Indent(buf, b, "", " ") + b = buf.Bytes() + if err != nil { + //TODO: handle it + return nil + } + } + + return w.WriteBytes(b) + } + + return w.WriteLine(cert.String()) +} + +func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { + args, ok := fs.(*sshPrintTunnelFlags) + if !ok { + //TODO: error + return nil + } + + if len(a) == 0 { + return w.WriteLine("No vpn ip was provided") + } + + vpnIp := ip2int(net.ParseIP(a[0])) + if vpnIp == 0 { + return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + } + + hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp)) + if err != nil { + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + } + + enc := json.NewEncoder(w.GetWriter()) + if args.Pretty { + enc.SetIndent("", " ") + } + + return enc.Encode(hostInfo) +} + +func sshReload(fs interface{}, a []string, w sshd.StringWriter) error { + p, err := os.FindProcess(os.Getpid()) + if err != nil { + return w.WriteLine(err.Error()) + //TODO + } + err = p.Signal(syscall.SIGHUP) + if err != nil { + return w.WriteLine(err.Error()) + //TODO + } + return w.WriteLine("HUP sent") +} diff --git a/sshd/command.go b/sshd/command.go new file mode 100644 index 0000000..a1568fe --- /dev/null +++ b/sshd/command.go @@ -0,0 +1,161 @@ +package sshd + +import ( + "errors" + "flag" + "fmt" + "github.com/armon/go-radix" + "sort" + "strings" +) + +// CommandFlags is a function called before help or command execution to parse command line flags +// It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags +type CommandFlags func() (*flag.FlagSet, interface{}) + +// CommandCallback is the function called when your command should execute. +// fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved +// and handled automatically for you. +// a will be any unconsumed arguments, if no Command.Flags was available this will be all the flags passed in. +// w is the writer to use when sending messages back to the client. +// If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user +// where appropriate +type CommandCallback func(fs interface{}, a []string, w StringWriter) error + +type Command struct { + Name string + ShortDescription string + Help string + Flags CommandFlags + Callback CommandCallback +} + +func execCommand(c *Command, args []string, w StringWriter) error { + var ( + fl *flag.FlagSet + fs interface{} + ) + + if c.Flags != nil { + fl, fs = c.Flags() + if fl != nil { + //TODO: handle the error + fl.Parse(args) + args = fl.Args() + } + } + + return c.Callback(fs, args, w) +} + +func dumpCommands(c *radix.Tree, w StringWriter) { + err := w.WriteLine("Available commands:") + if err != nil { + //TODO: log + return + } + + cmds := make([]string, 0) + for _, l := range allCommands(c) { + cmds = append(cmds, fmt.Sprintf("%s - %s", l.Name, l.ShortDescription)) + } + + sort.Strings(cmds) + err = w.Write(strings.Join(cmds, "\n") + "\n\n") + if err != nil { + //TODO: log + } +} + +func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) { + cmd, ok := c.Get(sCmd) + if !ok { + return nil, nil + } + + command, ok := cmd.(*Command) + if !ok { + return nil, errors.New("failed to cast command") + } + + return command, nil +} + +func matchCommand(c *radix.Tree, cmd string) []string { + cmds := make([]string, 0) + c.WalkPrefix(cmd, func(found string, v interface{}) bool { + cmds = append(cmds, found) + return false + }) + sort.Strings(cmds) + return cmds +} + +func allCommands(c *radix.Tree) []*Command { + cmds := make([]*Command, 0) + c.WalkPrefix("", func(found string, v interface{}) bool { + cmd, ok := v.(*Command) + if ok { + cmds = append(cmds, cmd) + } + return false + }) + return cmds +} + +func helpCallback(commands *radix.Tree, a []string, w StringWriter) (err error) { + // Just typed help + if len(a) == 0 { + dumpCommands(commands, w) + return nil + } + + // We are printing a specific commands help text + cmd, err := lookupCommand(commands, a[0]) + if err != nil { + //TODO: handle error + //TODO: message the user + return + } + + if cmd != nil { + err = w.WriteLine(fmt.Sprintf("%s - %s", cmd.Name, cmd.ShortDescription)) + if err != nil { + return err + } + + if cmd.Help != "" { + err = w.WriteLine(fmt.Sprintf(" %s", cmd.Help)) + if err != nil { + return err + } + } + + if cmd.Flags != nil { + fs, _ := cmd.Flags() + if fs != nil { + fs.SetOutput(w.GetWriter()) + fs.PrintDefaults() + } + } + + return nil + } + + err = w.WriteLine("Command not available " + a[0]) + if err != nil { + return err + } + + return nil +} + +func checkHelpArgs(args []string) bool { + for _, a := range args { + if a == "-h" || a == "-help" { + return true + } + } + + return false +} diff --git a/sshd/server.go b/sshd/server.go new file mode 100644 index 0000000..c93f597 --- /dev/null +++ b/sshd/server.go @@ -0,0 +1,182 @@ +package sshd + +import ( + "fmt" + "github.com/armon/go-radix" + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "net" +) + +type SSHServer struct { + config *ssh.ServerConfig + l *logrus.Entry + + // Map of user -> authorized keys + trustedKeys map[string]map[string]bool + + // List of available commands + helpCommand *Command + commands *radix.Tree + listener net.Listener + conns map[int]*session + counter int +} + +// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen +func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { + s := &SSHServer{ + trustedKeys: make(map[string]map[string]bool), + l: l, + commands: radix.New(), + conns: make(map[int]*session), + } + + s.config = &ssh.ServerConfig{ + PublicKeyCallback: s.matchPubKey, + //TODO: AuthLogCallback: s.authAttempt, + //TODO: version string + ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), + } + + s.RegisterCommand(&Command{ + Name: "help", + ShortDescription: "prints available commands or help for specific usage info", + Callback: func(a interface{}, args []string, w StringWriter) error { + return helpCallback(s.commands, args, w) + }, + }) + + return s, nil +} + +func (s *SSHServer) SetHostKey(hostPrivateKey []byte) error { + private, err := ssh.ParsePrivateKey(hostPrivateKey) + if err != nil { + return fmt.Errorf("failed to parse private key: %s", err) + } + + s.config.AddHostKey(private) + return nil +} + +func (s *SSHServer) ClearAuthorizedKeys() { + s.trustedKeys = make(map[string]map[string]bool) +} + +// AddAuthorizedKey adds an ssh public key for a user +func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error { + pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey)) + if err != nil { + return err + } + + tk, ok := s.trustedKeys[user] + if !ok { + tk = make(map[string]bool) + s.trustedKeys[user] = tk + } + + tk[string(pk.Marshal())] = true + s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key") + return nil +} + +// RegisterCommand adds a command that can be run by a user, by default only `help` is available +func (s *SSHServer) RegisterCommand(c *Command) { + s.commands.Insert(c.Name, c) +} + +// Run begins listening and accepting connections +func (s *SSHServer) Run(addr string) error { + var err error + s.listener, err = net.Listen("tcp", addr) + if err != nil { + return err + } + + s.l.WithField("sshListener", addr).Info("SSH server is listening") + for { + c, err := s.listener.Accept() + if err != nil { + s.l.WithError(err).Warn("Error in listener, shutting down") + return nil + } + + conn, chans, reqs, err := ssh.NewServerConn(c, s.config) + fp := "" + if conn != nil { + fp = conn.Permissions.Extensions["fp"] + } + + if err != nil { + l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) + if conn != nil { + l = l.WithField("sshUser", conn.User()) + conn.Close() + } + if fp != "" { + l = l.WithField("sshFingerprint", fp) + } + l.Warn("failed to handshake") + continue + } + + l := s.l.WithField("sshUser", conn.User()) + l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") + + session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session")) + s.counter++ + counter := s.counter + s.conns[counter] = session + + go ssh.DiscardRequests(reqs) + go func() { + <-session.exitChan + s.l.WithField("id", counter).Debug("closing conn") + delete(s.conns, counter) + }() + } +} + +func (s *SSHServer) Stop() { + for _, c := range s.conns { + c.Close() + } + + if s.listener == nil { + return + } + + err := s.listener.Close() + if err != nil { + s.l.WithError(err).Warn("Failed to close the sshd listener") + return + } + + s.l.Info("SSH server stopped listening") + return +} + +func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + pk := string(pubKey.Marshal()) + fp := ssh.FingerprintSHA256(pubKey) + + tk, ok := s.trustedKeys[c.User()] + if !ok { + return nil, fmt.Errorf("unknown user %s", c.User()) + } + + _, ok = tk[pk] + if !ok { + return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp) + } + + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "fp": fp, + "user": c.User(), + }, + }, nil +} diff --git a/sshd/session.go b/sshd/session.go new file mode 100644 index 0000000..2ba1c0f --- /dev/null +++ b/sshd/session.go @@ -0,0 +1,182 @@ +package sshd + +import ( + "fmt" + "github.com/anmitsu/go-shlex" + "github.com/armon/go-radix" + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" + "sort" + "strings" +) + +type session struct { + l *logrus.Entry + c *ssh.ServerConn + term *terminal.Terminal + commands *radix.Tree + exitChan chan bool +} + +func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session { + s := &session{ + commands: radix.NewFromMap(commands.ToMap()), + l: l, + c: conn, + exitChan: make(chan bool), + } + + s.commands.Insert("logout", &Command{ + Name: "logout", + ShortDescription: "Ends the current session", + Callback: func(a interface{}, args []string, w StringWriter) error { + s.Close() + return nil + }, + }) + + go s.handleChannels(chans) + return s +} + +func (s *session) handleChannels(chans <-chan ssh.NewChannel) { + for newChannel := range chans { + if newChannel.ChannelType() != "session" { + s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type") + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + + channel, requests, err := newChannel.Accept() + if err != nil { + s.l.WithError(err).Warn("could not accept channel") + continue + } + + go s.handleRequests(requests, channel) + } +} + +func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { + for req := range in { + var err error + //TODO: maybe support window sizing? + switch req.Type { + case "shell": + if s.term == nil { + s.term = s.createTerm(channel) + err = req.Reply(true, nil) + } else { + err = req.Reply(false, nil) + } + + case "pty-req": + err = req.Reply(true, nil) + + case "window-change": + err = req.Reply(true, nil) + + case "exec": + var payload = struct{ Value string }{} + cErr := ssh.Unmarshal(req.Payload, &payload) + if cErr == nil { + s.dispatchCommand(payload.Value, &stringWriter{channel}) + } else { + //TODO: log it + } + channel.Close() + return + + default: + s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request") + err = req.Reply(false, nil) + } + + if err != nil { + s.l.WithError(err).Info("Error handling ssh session requests") + s.Close() + return + } + } +} + +func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal { + //TODO: PS1 with nebula cert name + term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ") + term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { + // key 9 is tab + if key == 9 { + cmds := matchCommand(s.commands, line) + if len(cmds) == 1 { + return cmds[0] + " ", len(cmds[0]) + 1, true + } + + sort.Strings(cmds) + term.Write([]byte(strings.Join(cmds, "\n") + "\n\n")) + } + + return "", 0, false + } + + go s.handleInput(channel) + return term +} + +func (s *session) handleInput(channel ssh.Channel) { + defer s.Close() + w := &stringWriter{w: s.term} + for { + line, err := s.term.ReadLine() + if err != nil { + //TODO: log + break + } + + s.dispatchCommand(line, w) + } +} + +func (s *session) dispatchCommand(line string, w StringWriter) { + args, err := shlex.Split(line, true) + if err != nil { + //todo: LOG IT + return + } + + if len(args) == 0 { + dumpCommands(s.commands, w) + return + } + + c, err := lookupCommand(s.commands, args[0]) + if err != nil { + //TODO: handle the error + return + } + + if c == nil { + err := w.WriteLine(fmt.Sprintf("did not understand: %s", line)) + //TODO: log error + _ = err + + dumpCommands(s.commands, w) + return + } + + if checkHelpArgs(args) { + s.dispatchCommand(fmt.Sprintf("%s %s", "help", c.Name), w) + return + } + + err = execCommand(c, args[1:], w) + if err != nil { + //TODO: log the error + } + return +} + +func (s *session) Close() { + s.c.Close() + s.exitChan <- true +} diff --git a/sshd/writer.go b/sshd/writer.go new file mode 100644 index 0000000..8354c09 --- /dev/null +++ b/sshd/writer.go @@ -0,0 +1,32 @@ +package sshd + +import "io" + +type StringWriter interface { + WriteLine(string) error + Write(string) error + WriteBytes([]byte) error + GetWriter() io.Writer +} + +type stringWriter struct { + w io.Writer +} + +func (w *stringWriter) WriteLine(s string) error { + return w.Write(s + "\n") +} + +func (w *stringWriter) Write(s string) error { + _, err := w.w.Write([]byte(s)) + return err +} + +func (w *stringWriter) WriteBytes(b []byte) error { + _, err := w.w.Write(b) + return err +} + +func (w *stringWriter) GetWriter() io.Writer { + return w.w +} diff --git a/stats.go b/stats.go new file mode 100644 index 0000000..4765f2b --- /dev/null +++ b/stats.go @@ -0,0 +1,89 @@ +package nebula + +import ( + "errors" + "fmt" + "github.com/cyberdelia/go-metrics-graphite" + mp "github.com/nbrownus/go-metrics-prometheus" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/rcrowley/go-metrics" + "log" + "net" + "net/http" + "time" +) + +func startStats(c *Config) error { + mType := c.GetString("stats.type", "") + if mType == "" || mType == "none" { + return nil + } + + interval := c.GetDuration("stats.interval", 0) + if interval == 0 { + return fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", "")) + } + + switch mType { + case "graphite": + startGraphiteStats(interval, c) + case "prometheus": + startPrometheusStats(interval, c) + default: + return fmt.Errorf("stats.type was not understood: %s", mType) + } + + metrics.RegisterDebugGCStats(metrics.DefaultRegistry) + metrics.RegisterRuntimeMemStats(metrics.DefaultRegistry) + + go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval) + go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval) + + return nil +} + +func startGraphiteStats(i time.Duration, c *Config) error { + proto := c.GetString("stats.protocol", "tcp") + host := c.GetString("stats.host", "") + if host == "" { + return errors.New("stats.host can not be empty") + } + + prefix := c.GetString("stats.prefix", "nebula") + addr, err := net.ResolveTCPAddr(proto, host) + if err != nil { + return fmt.Errorf("error while setting up graphite sink: %s", err) + } + + l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr) + go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr) + return nil +} + +func startPrometheusStats(i time.Duration, c *Config) error { + namespace := c.GetString("stats.namespace", "") + subsystem := c.GetString("stats.subsystem", "") + + listen := c.GetString("stats.listen", "") + if listen == "" { + return fmt.Errorf("stats.listen should not be emtpy") + } + + path := c.GetString("stats.path", "") + if path == "" { + return fmt.Errorf("stats.path should not be emtpy") + } + + pr := prometheus.NewRegistry() + pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i) + go pClient.UpdatePrometheusMetrics() + + go func() { + l.Infof("Prometheus stats listening on %s at %s", listen, path) + http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l})) + log.Fatal(http.ListenAndServe(listen, nil)) + }() + + return nil +} diff --git a/timeout.go b/timeout.go new file mode 100644 index 0000000..6e80614 --- /dev/null +++ b/timeout.go @@ -0,0 +1,191 @@ +package nebula + +import ( + "time" +) + +// How many timer objects should be cached +const timerCacheMax = 50000 + +var emptyFWPacket = FirewallPacket{} + +type TimerWheel struct { + // Current tick + current int + + // Cheat on finding the length of the wheel + wheelLen int + + // Last time we ticked, since we are lazy ticking + lastTick *time.Time + + // Durations of a tick and the entire wheel + tickDuration time.Duration + wheelDuration time.Duration + + // The actual wheel which is just a set of singly linked lists, head/tail pointers + wheel []*TimeoutList + + // Singly linked list of items that have timed out of the wheel + expired *TimeoutList + + // Item cache to avoid garbage collect + itemCache *TimeoutItem + itemsCached int +} + +// Represents a tick in the wheel +type TimeoutList struct { + Head *TimeoutItem + Tail *TimeoutItem +} + +// Represents an item within a tick +type TimeoutItem struct { + Packet FirewallPacket + Next *TimeoutItem +} + +// Builds a timer wheel and identifies the tick duration and wheel duration from the provided values +// Purge must be called once per entry to actually remove anything +func NewTimerWheel(min, max time.Duration) *TimerWheel { + //TODO provide an error + //if min >= max { + // return nil + //} + + // Round down and add 1 so we can have the smallest # of ticks in the wheel and still account for a full + // max duration + wLen := int((max / min) + 1) + + tw := TimerWheel{ + wheelLen: wLen, + wheel: make([]*TimeoutList, wLen), + tickDuration: min, + wheelDuration: max, + expired: &TimeoutList{}, + } + + for i := range tw.wheel { + tw.wheel[i] = &TimeoutList{} + } + + return &tw +} + +// Add will add a FirewallPacket to the wheel in it's proper timeout +func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem { + // Check and see if we should progress the tick + tw.advance(time.Now()) + + i := tw.findWheel(timeout) + + // Try to fetch off the cache + ti := tw.itemCache + if ti != nil { + tw.itemCache = ti.Next + tw.itemsCached-- + ti.Next = nil + } else { + ti = &TimeoutItem{} + } + + // Relink and return + ti.Packet = v + if tw.wheel[i].Tail == nil { + tw.wheel[i].Head = ti + tw.wheel[i].Tail = ti + } else { + tw.wheel[i].Tail.Next = ti + tw.wheel[i].Tail = ti + } + + return ti +} + +func (tw *TimerWheel) Purge() (FirewallPacket, bool) { + if tw.expired.Head == nil { + return emptyFWPacket, false + } + + ti := tw.expired.Head + tw.expired.Head = ti.Next + + if tw.expired.Head == nil { + tw.expired.Tail = nil + } + + // Clear out the items references + ti.Next = nil + + // Maybe cache it for later + if tw.itemsCached < timerCacheMax { + ti.Next = tw.itemCache + tw.itemCache = ti + tw.itemsCached++ + } + + return ti.Packet, true +} + +// advance will move the wheel forward by proper number of ticks. The caller _should_ lock the wheel before calling this +func (tw *TimerWheel) findWheel(timeout time.Duration) (i int) { + if timeout < tw.tickDuration { + // Can't track anything below the set resolution + timeout = tw.tickDuration + } else if timeout > tw.wheelDuration { + // We aren't handling timeouts greater than the wheels duration + timeout = tw.wheelDuration + } + + // Find the next highest, rounding up + tick := int(((timeout - 1) / tw.tickDuration) + 1) + + // Add another tick since the current tick may almost be over then map it to the wheel from our + // current position + tick += tw.current + 1 + if tick >= tw.wheelLen { + tick -= tw.wheelLen + } + + return tick +} + +// advance will lock and move the wheel forward by proper number of ticks. +func (tw *TimerWheel) advance(now time.Time) { + if tw.lastTick == nil { + tw.lastTick = &now + } + + // We want to round down + ticks := int(now.Sub(*tw.lastTick) / tw.tickDuration) + adv := ticks + if ticks > tw.wheelLen { + ticks = tw.wheelLen + } + + for i := 0; i < ticks; i++ { + tw.current++ + if tw.current >= tw.wheelLen { + tw.current = 0 + } + + if tw.wheel[tw.current].Head != nil { + // We need to append the expired items as to not starve evicting the oldest ones + if tw.expired.Tail == nil { + tw.expired.Head = tw.wheel[tw.current].Head + tw.expired.Tail = tw.wheel[tw.current].Tail + } else { + tw.expired.Tail.Next = tw.wheel[tw.current].Head + tw.expired.Tail = tw.wheel[tw.current].Tail + } + + tw.wheel[tw.current].Head = nil + tw.wheel[tw.current].Tail = nil + } + } + + // Advance the tick based on duration to avoid losing some accuracy + newTick := tw.lastTick.Add(tw.tickDuration * time.Duration(adv)) + tw.lastTick = &newTick +} diff --git a/timeout_system.go b/timeout_system.go new file mode 100644 index 0000000..e458458 --- /dev/null +++ b/timeout_system.go @@ -0,0 +1,196 @@ +package nebula + +import ( + "sync" + "time" +) + +// How many timer objects should be cached +const systemTimerCacheMax = 50000 + +type SystemTimerWheel struct { + // Current tick + current int + + // Cheat on finding the length of the wheel + wheelLen int + + // Last time we ticked, since we are lazy ticking + lastTick *time.Time + + // Durations of a tick and the entire wheel + tickDuration time.Duration + wheelDuration time.Duration + + // The actual wheel which is just a set of singly linked lists, head/tail pointers + wheel []*SystemTimeoutList + + // Singly linked list of items that have timed out of the wheel + expired *SystemTimeoutList + + // Item cache to avoid garbage collect + itemCache *SystemTimeoutItem + itemsCached int + + lock sync.Mutex +} + +// Represents a tick in the wheel +type SystemTimeoutList struct { + Head *SystemTimeoutItem + Tail *SystemTimeoutItem +} + +// Represents an item within a tick +type SystemTimeoutItem struct { + Item uint32 + Next *SystemTimeoutItem +} + +// Builds a timer wheel and identifies the tick duration and wheel duration from the provided values +// Purge must be called once per entry to actually remove anything +func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel { + //TODO provide an error + //if min >= max { + // return nil + //} + + // Round down and add 1 so we can have the smallest # of ticks in the wheel and still account for a full + // max duration + wLen := int((max / min) + 1) + + tw := SystemTimerWheel{ + wheelLen: wLen, + wheel: make([]*SystemTimeoutList, wLen), + tickDuration: min, + wheelDuration: max, + expired: &SystemTimeoutList{}, + } + + for i := range tw.wheel { + tw.wheel[i] = &SystemTimeoutList{} + } + + return &tw +} + +func (tw *SystemTimerWheel) Add(v uint32, timeout time.Duration) *SystemTimeoutItem { + tw.lock.Lock() + defer tw.lock.Unlock() + + // Check and see if we should progress the tick + //tw.advance(time.Now()) + + i := tw.findWheel(timeout) + + // Try to fetch off the cache + ti := tw.itemCache + if ti != nil { + tw.itemCache = ti.Next + ti.Next = nil + tw.itemsCached-- + } else { + ti = &SystemTimeoutItem{} + } + + // Relink and return + ti.Item = v + ti.Next = tw.wheel[i].Head + tw.wheel[i].Head = ti + + if tw.wheel[i].Tail == nil { + tw.wheel[i].Tail = ti + } + + return ti +} + +func (tw *SystemTimerWheel) Purge() interface{} { + tw.lock.Lock() + defer tw.lock.Unlock() + + if tw.expired.Head == nil { + return nil + } + + ti := tw.expired.Head + tw.expired.Head = ti.Next + + if tw.expired.Head == nil { + tw.expired.Tail = nil + } + + p := ti.Item + + // Clear out the items references + ti.Item = 0 + ti.Next = nil + + // Maybe cache it for later + if tw.itemsCached < systemTimerCacheMax { + ti.Next = tw.itemCache + tw.itemCache = ti + tw.itemsCached++ + } + + return p +} + +func (tw *SystemTimerWheel) findWheel(timeout time.Duration) (i int) { + if timeout < tw.tickDuration { + // Can't track anything below the set resolution + timeout = tw.tickDuration + } else if timeout > tw.wheelDuration { + // We aren't handling timeouts greater than the wheels duration + timeout = tw.wheelDuration + } + + // Find the next highest, rounding up + tick := int(((timeout - 1) / tw.tickDuration) + 1) + + // Add another tick since the current tick may almost be over then map it to the wheel from our + // current position + tick += tw.current + 1 + if tick >= tw.wheelLen { + tick -= tw.wheelLen + } + + return tick +} + +func (tw *SystemTimerWheel) advance(now time.Time) { + tw.lock.Lock() + defer tw.lock.Unlock() + + if tw.lastTick == nil { + tw.lastTick = &now + } + + // We want to round down + ticks := int(now.Sub(*tw.lastTick) / tw.tickDuration) + //l.Infoln("Ticks: ", ticks) + for i := 0; i < ticks; i++ { + tw.current++ + //l.Infoln("Tick: ", tw.current) + if tw.current >= tw.wheelLen { + tw.current = 0 + } + + // We need to append the expired items as to not starve evicting the oldest ones + if tw.expired.Tail == nil { + tw.expired.Head = tw.wheel[tw.current].Head + tw.expired.Tail = tw.wheel[tw.current].Tail + } else { + tw.expired.Tail.Next = tw.wheel[tw.current].Head + if tw.wheel[tw.current].Tail != nil { + tw.expired.Tail = tw.wheel[tw.current].Tail + } + } + + //l.Infoln("Head: ", tw.expired.Head, "Tail: ", tw.expired.Tail) + tw.wheel[tw.current].Head = nil + tw.wheel[tw.current].Tail = nil + + tw.lastTick = &now + } +} diff --git a/timeout_system_test.go b/timeout_system_test.go new file mode 100644 index 0000000..712725d --- /dev/null +++ b/timeout_system_test.go @@ -0,0 +1,134 @@ +package nebula + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewSystemTimerWheel(t *testing.T) { + // Make sure we get an object we expect + tw := NewSystemTimerWheel(time.Second, time.Second*10) + assert.Equal(t, 11, tw.wheelLen) + assert.Equal(t, 0, tw.current) + assert.Nil(t, tw.lastTick) + assert.Equal(t, time.Second*1, tw.tickDuration) + assert.Equal(t, time.Second*10, tw.wheelDuration) + assert.Len(t, tw.wheel, 11) + + // Assert the math is correct + tw = NewSystemTimerWheel(time.Second*3, time.Second*10) + assert.Equal(t, 4, tw.wheelLen) + + tw = NewSystemTimerWheel(time.Second*120, time.Minute*10) + assert.Equal(t, 6, tw.wheelLen) +} + +func TestSystemTimerWheel_findWheel(t *testing.T) { + tw := NewSystemTimerWheel(time.Second, time.Second*10) + assert.Len(t, tw.wheel, 11) + + // Current + tick + 1 since we don't know how far into current we are + assert.Equal(t, 2, tw.findWheel(time.Second*1)) + + // Scale up to min duration + assert.Equal(t, 2, tw.findWheel(time.Millisecond*1)) + + // Make sure we hit that last index + assert.Equal(t, 0, tw.findWheel(time.Second*10)) + + // Scale down to max duration + assert.Equal(t, 0, tw.findWheel(time.Second*11)) + + tw.current = 1 + // Make sure we account for the current position properly + assert.Equal(t, 3, tw.findWheel(time.Second*1)) + assert.Equal(t, 1, tw.findWheel(time.Second*10)) +} + +func TestSystemTimerWheel_Add(t *testing.T) { + tw := NewSystemTimerWheel(time.Second, time.Second*10) + + fp1 := ip2int(net.ParseIP("1.2.3.4")) + tw.Add(fp1, time.Second*1) + + // Make sure we set head and tail properly + assert.NotNil(t, tw.wheel[2]) + assert.Equal(t, fp1, tw.wheel[2].Head.Item) + assert.Nil(t, tw.wheel[2].Head.Next) + assert.Equal(t, fp1, tw.wheel[2].Tail.Item) + assert.Nil(t, tw.wheel[2].Tail.Next) + + // Make sure we only modify head + fp2 := ip2int(net.ParseIP("1.2.3.4")) + tw.Add(fp2, time.Second*1) + assert.Equal(t, fp2, tw.wheel[2].Head.Item) + assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item) + assert.Equal(t, fp1, tw.wheel[2].Tail.Item) + assert.Nil(t, tw.wheel[2].Tail.Next) + + // Make sure we use free'd items first + tw.itemCache = &SystemTimeoutItem{} + tw.itemsCached = 1 + tw.Add(fp2, time.Second*1) + assert.Nil(t, tw.itemCache) + assert.Equal(t, 0, tw.itemsCached) +} + +func TestSystemTimerWheel_Purge(t *testing.T) { + // First advance should set the lastTick and do nothing else + tw := NewSystemTimerWheel(time.Second, time.Second*10) + assert.Nil(t, tw.lastTick) + tw.advance(time.Now()) + assert.NotNil(t, tw.lastTick) + assert.Equal(t, 0, tw.current) + + fps := []uint32{9, 10, 11, 12} + + //fp1 := ip2int(net.ParseIP("1.2.3.4")) + + tw.Add(fps[0], time.Second*1) + tw.Add(fps[1], time.Second*1) + tw.Add(fps[2], time.Second*2) + tw.Add(fps[3], time.Second*2) + + ta := time.Now().Add(time.Second * 3) + lastTick := *tw.lastTick + tw.advance(ta) + assert.Equal(t, 3, tw.current) + assert.True(t, tw.lastTick.After(lastTick)) + + // Make sure we get all 4 packets back + for i := 0; i < 4; i++ { + assert.Contains(t, fps, tw.Purge()) + } + + // Make sure there aren't any leftover + assert.Nil(t, tw.Purge()) + assert.Nil(t, tw.expired.Head) + assert.Nil(t, tw.expired.Tail) + + // Make sure we cached the free'd items + assert.Equal(t, 4, tw.itemsCached) + ci := tw.itemCache + for i := 0; i < 4; i++ { + assert.NotNil(t, ci) + ci = ci.Next + } + assert.Nil(t, ci) + + // Lets make sure we roll over properly + ta = ta.Add(time.Second * 5) + tw.advance(ta) + assert.Equal(t, 8, tw.current) + + ta = ta.Add(time.Second * 2) + tw.advance(ta) + assert.Equal(t, 10, tw.current) + + ta = ta.Add(time.Second * 1) + tw.advance(ta) + assert.Equal(t, 0, tw.current) +} diff --git a/timeout_test.go b/timeout_test.go new file mode 100644 index 0000000..6b862a4 --- /dev/null +++ b/timeout_test.go @@ -0,0 +1,138 @@ +package nebula + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestNewTimerWheel(t *testing.T) { + // Make sure we get an object we expect + tw := NewTimerWheel(time.Second, time.Second*10) + assert.Equal(t, 11, tw.wheelLen) + assert.Equal(t, 0, tw.current) + assert.Nil(t, tw.lastTick) + assert.Equal(t, time.Second*1, tw.tickDuration) + assert.Equal(t, time.Second*10, tw.wheelDuration) + assert.Len(t, tw.wheel, 11) + + // Assert the math is correct + tw = NewTimerWheel(time.Second*3, time.Second*10) + assert.Equal(t, 4, tw.wheelLen) + + tw = NewTimerWheel(time.Second*120, time.Minute*10) + assert.Equal(t, 6, tw.wheelLen) +} + +func TestTimerWheel_findWheel(t *testing.T) { + tw := NewTimerWheel(time.Second, time.Second*10) + assert.Len(t, tw.wheel, 11) + + // Current + tick + 1 since we don't know how far into current we are + assert.Equal(t, 2, tw.findWheel(time.Second*1)) + + // Scale up to min duration + assert.Equal(t, 2, tw.findWheel(time.Millisecond*1)) + + // Make sure we hit that last index + assert.Equal(t, 0, tw.findWheel(time.Second*10)) + + // Scale down to max duration + assert.Equal(t, 0, tw.findWheel(time.Second*11)) + + tw.current = 1 + // Make sure we account for the current position properly + assert.Equal(t, 3, tw.findWheel(time.Second*1)) + assert.Equal(t, 1, tw.findWheel(time.Second*10)) +} + +func TestTimerWheel_Add(t *testing.T) { + tw := NewTimerWheel(time.Second, time.Second*10) + + fp1 := FirewallPacket{} + tw.Add(fp1, time.Second*1) + + // Make sure we set head and tail properly + assert.NotNil(t, tw.wheel[2]) + assert.Equal(t, fp1, tw.wheel[2].Head.Packet) + assert.Nil(t, tw.wheel[2].Head.Next) + assert.Equal(t, fp1, tw.wheel[2].Tail.Packet) + assert.Nil(t, tw.wheel[2].Tail.Next) + + // Make sure we only modify head + fp2 := FirewallPacket{} + tw.Add(fp2, time.Second*1) + assert.Equal(t, fp2, tw.wheel[2].Head.Packet) + assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet) + assert.Equal(t, fp1, tw.wheel[2].Tail.Packet) + assert.Nil(t, tw.wheel[2].Tail.Next) + + // Make sure we use free'd items first + tw.itemCache = &TimeoutItem{} + tw.itemsCached = 1 + tw.Add(fp2, time.Second*1) + assert.Nil(t, tw.itemCache) + assert.Equal(t, 0, tw.itemsCached) +} + +func TestTimerWheel_Purge(t *testing.T) { + // First advance should set the lastTick and do nothing else + tw := NewTimerWheel(time.Second, time.Second*10) + assert.Nil(t, tw.lastTick) + tw.advance(time.Now()) + assert.NotNil(t, tw.lastTick) + assert.Equal(t, 0, tw.current) + + fps := []FirewallPacket{ + {LocalIP: 1}, + {LocalIP: 2}, + {LocalIP: 3}, + {LocalIP: 4}, + } + + tw.Add(fps[0], time.Second*1) + tw.Add(fps[1], time.Second*1) + tw.Add(fps[2], time.Second*2) + tw.Add(fps[3], time.Second*2) + + ta := time.Now().Add(time.Second * 3) + lastTick := *tw.lastTick + tw.advance(ta) + assert.Equal(t, 3, tw.current) + assert.True(t, tw.lastTick.After(lastTick)) + + // Make sure we get all 4 packets back + for i := 0; i < 4; i++ { + p, has := tw.Purge() + assert.True(t, has) + assert.Equal(t, fps[i], p) + } + + // Make sure there aren't any leftover + _, ok := tw.Purge() + assert.False(t, ok) + assert.Nil(t, tw.expired.Head) + assert.Nil(t, tw.expired.Tail) + + // Make sure we cached the free'd items + assert.Equal(t, 4, tw.itemsCached) + ci := tw.itemCache + for i := 0; i < 4; i++ { + assert.NotNil(t, ci) + ci = ci.Next + } + assert.Nil(t, ci) + + // Lets make sure we roll over properly + ta = ta.Add(time.Second * 5) + tw.advance(ta) + assert.Equal(t, 8, tw.current) + + ta = ta.Add(time.Second * 2) + tw.advance(ta) + assert.Equal(t, 10, tw.current) + + ta = ta.Add(time.Second * 1) + tw.advance(ta) + assert.Equal(t, 0, tw.current) +} diff --git a/tun_common.go b/tun_common.go new file mode 100644 index 0000000..0731968 --- /dev/null +++ b/tun_common.go @@ -0,0 +1,108 @@ +package nebula + +import ( + "fmt" + "net" + "strconv" +) + +type route struct { + mtu int + route *net.IPNet +} + +func parseRoutes(config *Config, network *net.IPNet) ([]route, error) { + var err error + + r := config.Get("tun.routes") + if r == nil { + return []route{}, nil + } + + rawRoutes, ok := r.([]interface{}) + if !ok { + return nil, fmt.Errorf("tun.routes is not an array") + } + + if len(rawRoutes) < 1 { + return []route{}, nil + } + + routes := make([]route, len(rawRoutes)) + for i, r := range rawRoutes { + m, ok := r.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1) + } + + rMtu, ok := m["mtu"] + if !ok { + return nil, fmt.Errorf("entry %v.mtu in tun.routes is not present", i+1) + } + + mtu, ok := rMtu.(int) + if !ok { + mtu, err = strconv.Atoi(rMtu.(string)) + if err != nil { + return nil, fmt.Errorf("entry %v.mtu in tun.routes is not an integer: %v", i+1, err) + } + } + + if mtu < 500 { + return nil, fmt.Errorf("entry %v.mtu in tun.routes is below 500: %v", i+1, mtu) + } + + rRoute, ok := m["route"] + if !ok { + return nil, fmt.Errorf("entry %v.route in tun.routes is not present", i+1) + } + + r := route{ + mtu: mtu, + } + + _, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + if err != nil { + return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) + } + + if !ipWithin(network, r.route) { + return nil, fmt.Errorf( + "entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", + i+1, + r.route.String(), + network.String(), + ) + } + + routes[i] = r + } + + return routes, nil +} + +func ipWithin(o *net.IPNet, i *net.IPNet) bool { + // Make sure o contains the lowest form of i + if !o.Contains(i.IP.Mask(i.Mask)) { + return false + } + + // Find the max ip in i + ip4 := i.IP.To4() + if ip4 == nil { + return false + } + + last := make(net.IP, len(ip4)) + copy(last, ip4) + for x := range ip4 { + last[x] |= ^i.Mask[x] + } + + // Make sure o contains the max + if !o.Contains(last) { + return false + } + + return true +} diff --git a/tun_darwin.go b/tun_darwin.go new file mode 100644 index 0000000..43fc4fd --- /dev/null +++ b/tun_darwin.go @@ -0,0 +1,59 @@ +package nebula + +import ( + "fmt" + "net" + "os/exec" + "strconv" + + "github.com/songgao/water" +) + +type Tun struct { + Device string + Cidr *net.IPNet + MTU int + + *water.Interface +} + +func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) { + if len(routes) > 0 { + return nil, fmt.Errorf("Route MTU not supported in Darwin") + } + // NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate() + return &Tun{ + Cidr: cidr, + MTU: defaultMTU, + }, nil +} + +func (c *Tun) Activate() error { + var err error + c.Interface, err = water.New(water.Config{ + DeviceType: water.TUN, + }) + if err != nil { + return fmt.Errorf("Activate failed: %v", err) + } + + c.Device = c.Interface.Name() + + // TODO use syscalls instead of exec.Command + if err = exec.Command("ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil { + return fmt.Errorf("failed to run 'ifconfig': %s", err) + } + if err = exec.Command("route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil { + return fmt.Errorf("failed to run 'route add': %s", err) + } + if err = exec.Command("ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil { + return fmt.Errorf("failed to run 'ifconfig': %s", err) + } + + return nil +} + +func (c *Tun) WriteRaw(b []byte) error { + _, err := c.Write(b) + return err +} diff --git a/tun_linux.go b/tun_linux.go new file mode 100644 index 0000000..2a1a197 --- /dev/null +++ b/tun_linux.go @@ -0,0 +1,249 @@ +package nebula + +import ( + "fmt" + "io" + "net" + "os" + "strings" + "syscall" + "unsafe" + + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" +) + +type Tun struct { + io.ReadWriteCloser + fd int + Device string + Cidr *net.IPNet + MaxMTU int + DefaultMTU int + TXQueueLen int + Routes []route +} + +type ifReq struct { + Name [16]byte + Flags uint16 + pad [8]byte +} + +func ioctl(a1, a2, a3 uintptr) error { + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3) + if errno != 0 { + return errno + } + return nil +} + +/* +func ipv4(addr string) (o [4]byte, err error) { + ip := net.ParseIP(addr).To4() + if ip == nil { + err = fmt.Errorf("failed to parse addr %s", addr) + return + } + for i, b := range ip { + o[i] = b + } + return +} +*/ + +const ( + cIFF_TUN = 0x0001 + cIFF_NO_PI = 0x1000 +) + +type ifreqAddr struct { + Name [16]byte + Addr syscall.RawSockaddrInet4 + pad [8]byte +} + +type ifreqMTU struct { + Name [16]byte + MTU int + pad [8]byte +} + +type ifreqQLEN struct { + Name [16]byte + Value int + pad [8]byte +} + +func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) { + fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) + if err != nil { + return nil, err + } + + var req ifReq + req.Flags = uint16(cIFF_TUN | cIFF_NO_PI) + copy(req.Name[:], deviceName) + if err = ioctl(uintptr(fd), uintptr(syscall.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { + return + } + name := strings.Trim(string(req.Name[:]), "\x00") + + file := os.NewFile(uintptr(fd), "/dev/net/tun") + + maxMTU := defaultMTU + for _, r := range routes { + if r.mtu > maxMTU { + maxMTU = r.mtu + } + } + + ifce = &Tun{ + ReadWriteCloser: file, + fd: int(file.Fd()), + Device: name, + Cidr: cidr, + MaxMTU: maxMTU, + DefaultMTU: defaultMTU, + TXQueueLen: txQueueLen, + Routes: routes, + } + return +} + +func (c *Tun) WriteRaw(b []byte) error { + var nn int + for { + max := len(b) + n, err := syscall.Write(c.fd, b[nn:max]) + if n > 0 { + nn += n + } + if nn == len(b) { + return err + } + + if err != nil { + return err + } + + if n == 0 { + return io.ErrUnexpectedEOF + } + } +} + +func (c Tun) deviceBytes() (o [16]byte) { + for i, c := range c.Device { + o[i] = byte(c) + } + return +} + +func (c Tun) Activate() error { + devName := c.deviceBytes() + + var addr, mask [4]byte + + copy(addr[:], c.Cidr.IP.To4()) + copy(mask[:], c.Cidr.Mask) + + s, err := syscall.Socket( + syscall.AF_INET, + syscall.SOCK_DGRAM, + syscall.IPPROTO_IP, + ) + if err != nil { + return err + } + fd := uintptr(s) + + ifra := ifreqAddr{ + Name: devName, + Addr: syscall.RawSockaddrInet4{ + Family: syscall.AF_INET, + Addr: addr, + }, + } + + // Set the device ip address + if err = ioctl(fd, syscall.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { + return err + } + + // Set the device network + ifra.Addr.Addr = mask + if err = ioctl(fd, syscall.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { + return err + } + + // Set the device name + ifrf := ifReq{Name: devName} + if err = ioctl(fd, syscall.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return err + } + + // Set the MTU on the device + ifm := ifreqMTU{Name: devName, MTU: c.MaxMTU} + if err = ioctl(fd, syscall.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { + return err + } + + // Set the transmit queue length + ifrq := ifreqQLEN{Name: devName, Value: c.TXQueueLen} + if err = ioctl(fd, syscall.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { + return err + } + + // Bring up the interface + ifrf.Flags = ifrf.Flags | syscall.IFF_UP + if err = ioctl(fd, syscall.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return err + } + + // Set the routes + link, err := netlink.LinkByName(c.Device) + if err != nil { + return err + } + + // Default route + dr := &net.IPNet{IP: c.Cidr.IP.Mask(c.Cidr.Mask), Mask: c.Cidr.Mask} + nr := netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: dr, + MTU: c.DefaultMTU, + Scope: unix.RT_SCOPE_LINK, + Src: c.Cidr.IP, + Protocol: unix.RTPROT_KERNEL, + Table: unix.RT_TABLE_MAIN, + Type: unix.RTN_UNICAST, + } + err = netlink.RouteReplace(&nr) + if err != nil { + return fmt.Errorf("failed to set mtu %v on the default route %v; %v", c.DefaultMTU, dr, err) + } + + // Path routes + for _, r := range c.Routes { + nr := netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: r.route, + MTU: r.mtu, + Scope: unix.RT_SCOPE_LINK, + } + + err = netlink.RouteAdd(&nr) + if err != nil { + return fmt.Errorf("failed to set mtu %v on route %v; %v", r.mtu, r.route, err) + } + } + + // Run the interface + ifrf.Flags = ifrf.Flags | syscall.IFF_UP | syscall.IFF_RUNNING + if err = ioctl(fd, syscall.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return err + } + + return nil +} diff --git a/tun_test.go b/tun_test.go new file mode 100644 index 0000000..09651e5 --- /dev/null +++ b/tun_test.go @@ -0,0 +1,102 @@ +package nebula + +import ( + "github.com/stretchr/testify/assert" + "net" + "testing" +) + +func Test_parseRoutes(t *testing.T) { + c := NewConfig() + _, n, _ := net.ParseCIDR("10.0.0.0/24") + + // test no routes config + routes, err := parseRoutes(c, n) + assert.Nil(t, err) + assert.Len(t, routes, 0) + + // not an array + c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} + routes, err = parseRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "tun.routes is not an array") + + // no routes + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} + routes, err = parseRoutes(c, n) + assert.Nil(t, err) + assert.Len(t, routes, 0) + + // weird route + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} + routes, err = parseRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1 in tun.routes is invalid") + + // no mtu + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} + routes, err = parseRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") + + // bad mtu + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} + routes, err = parseRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") + + // low mtu + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} + routes, err = parseRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") + + // missing route + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} + routes, err = parseRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.route in tun.routes is not present") + + // unparsable route + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} + routes, err = parseRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope") + + // below network range + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} + routes, err = parseRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24") + + // above network range + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} + routes, err = parseRoutes(c, n) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24") + + // happy case + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ + map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, + map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, + }} + routes, err = parseRoutes(c, n) + assert.Nil(t, err) + assert.Len(t, routes, 2) + + tested := 0 + for _, r := range routes { + if r.mtu == 8000 { + assert.Equal(t, "10.0.0.1/32", r.route.String()) + tested++ + } else { + assert.Equal(t, 9000, r.mtu) + assert.Equal(t, "10.0.0.0/29", r.route.String()) + tested++ + } + } + + if tested != 2 { + t.Fatal("Did not see both routes") + } +} diff --git a/tun_windows.go b/tun_windows.go new file mode 100644 index 0000000..6c740a0 --- /dev/null +++ b/tun_windows.go @@ -0,0 +1,72 @@ +package nebula + +import ( + "fmt" + "net" + "os/exec" + + "github.com/songgao/water" +) + +type Tun struct { + Device string + Cidr *net.IPNet + MTU int + + *water.Interface +} + +func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) { + if len(routes) > 0 { + return nil, fmt.Errorf("Route MTU not supported in Windows") + } + // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() + return &Tun{ + Cidr: cidr, + MTU: defaultMTU, + }, nil +} + +func (c *Tun) Activate() error { + var err error + c.Interface, err = water.New(water.Config{ + DeviceType: water.TUN, + PlatformSpecificParams: water.PlatformSpecificParams{ + ComponentID: "tap0901", + Network: c.Cidr.String(), + }, + }) + if err != nil { + return fmt.Errorf("Activate failed: %v", err) + } + + c.Device = c.Interface.Name() + + // TODO use syscalls instead of exec.Command + err = exec.Command( + "netsh", "interface", "ipv4", "set", "address", + fmt.Sprintf("name=%s", c.Device), + "source=static", + fmt.Sprintf("addr=%s", c.Cidr.IP), + fmt.Sprintf("mask=%s", net.IP(c.Cidr.Mask)), + "gateway=none", + ).Run() + if err != nil { + return fmt.Errorf("failed to run 'netsh' to set address: %s", err) + } + err = exec.Command( + "netsh", "interface", "ipv4", "set", "interface", + c.Device, + fmt.Sprintf("mtu=%d", c.MTU), + ).Run() + if err != nil { + return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err) + } + + return nil +} + +func (c *Tun) WriteRaw(b []byte) error { + _, err := c.Write(b) + return err +} diff --git a/udp_darwin.go b/udp_darwin.go new file mode 100644 index 0000000..61fda7a --- /dev/null +++ b/udp_darwin.go @@ -0,0 +1,34 @@ +package nebula + +// Darwin support is primarily implemented in udp_generic, besides NewListenConfig + +import ( + "fmt" + "net" + "syscall" + + "golang.org/x/sys/unix" +) + +func NewListenConfig(multi bool) net.ListenConfig { + return net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + if multi { + var controlErr error + err := c.Control(func(fd uintptr) { + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err) + return + } + }) + if err != nil { + return err + } + if controlErr != nil { + return controlErr + } + } + return nil + }, + } +} diff --git a/udp_generic.go b/udp_generic.go new file mode 100644 index 0000000..0c988ab --- /dev/null +++ b/udp_generic.go @@ -0,0 +1,123 @@ +// +build !linux + +// udp_generic implements the nebula UDP interface in pure Go stdlib. This +// means it can be used on platforms like Darwin and Windows. + +package nebula + +import ( + "context" + "encoding/binary" + "fmt" + "net" + "strconv" + "strings" +) + +type udpAddr struct { + net.UDPAddr +} + +type udpConn struct { + *net.UDPConn +} + +func NewUDPAddr(ip uint32, port uint16) *udpAddr { + return &udpAddr{ + UDPAddr: net.UDPAddr{ + IP: int2ip(ip), + Port: int(port), + }, + } +} + +func NewUDPAddrFromString(s string) *udpAddr { + p := strings.Split(s, ":") + if len(p) < 2 { + return nil + } + + port, _ := strconv.Atoi(p[1]) + return &udpAddr{ + UDPAddr: net.UDPAddr{ + IP: net.ParseIP(p[0]), + Port: port, + }, + } +} + +func NewListener(ip string, port int, multi bool) (*udpConn, error) { + lc := NewListenConfig(multi) + pc, err := lc.ListenPacket(context.TODO(), "udp4", fmt.Sprintf("%s:%d", ip, port)) + if err != nil { + return nil, err + } + if uc, ok := pc.(*net.UDPConn); ok { + return &udpConn{UDPConn: uc}, nil + } + return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) +} + +func (ua *udpAddr) Equals(t *udpAddr) bool { + if t == nil || ua == nil { + return t == nil && ua == nil + } + return ua.IP.Equal(t.IP) && ua.Port == t.Port +} + +func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error { + _, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr) + return err +} + +func (uc *udpConn) LocalAddr() (*udpAddr, error) { + a := uc.UDPConn.LocalAddr() + + switch v := a.(type) { + case *net.UDPAddr: + return &udpAddr{UDPAddr: *v}, nil + default: + return nil, fmt.Errorf("LocalAddr returned: %#v", a) + } +} + +func (u *udpConn) reloadConfig(c *Config) { + // TODO +} + +type rawMessage struct { + Len uint32 +} + +func (u *udpConn) ListenOut(f *Interface) { + plaintext := make([]byte, mtu) + buffer := make([]byte, mtu) + header := &Header{} + fwPacket := &FirewallPacket{} + udpAddr := &udpAddr{} + nb := make([]byte, 12, 12) + + for { + // Just read one packet at a time + n, rua, err := u.ReadFromUDP(buffer) + if err != nil { + l.WithError(err).Error("Failed to read packets") + continue + } + + udpAddr.UDPAddr = *rua + f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, nb) + } +} + +func udp2ip(addr *udpAddr) net.IP { + return addr.IP +} + +func udp2ipInt(addr *udpAddr) uint32 { + return binary.BigEndian.Uint32(addr.IP.To4()) +} + +func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool { + return !addr.Equals(newaddr) +} diff --git a/udp_linux.go b/udp_linux.go new file mode 100644 index 0000000..593b896 --- /dev/null +++ b/udp_linux.go @@ -0,0 +1,315 @@ +package nebula + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "net" + "strconv" + "strings" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" +) + +//TODO: make it support reload as best you can! + +type udpConn struct { + sysFd int +} + +type udpAddr struct { + IP uint32 + Port uint16 +} + +func NewUDPAddr(ip uint32, port uint16) *udpAddr { + return &udpAddr{IP: ip, Port: port} +} + +func NewUDPAddrFromString(s string) *udpAddr { + p := strings.Split(s, ":") + if len(p) < 2 { + return nil + } + + port, _ := strconv.Atoi(p[1]) + return &udpAddr{ + IP: ip2int(net.ParseIP(p[0])), + Port: uint16(port), + } +} + +type rawSockaddr struct { + Family uint16 + Data [14]uint8 +} + +type rawSockaddrAny struct { + Addr rawSockaddr + Pad [96]int8 +} + +var x int + +func NewListener(ip string, port int, multi bool) (*udpConn, error) { + syscall.ForkLock.RLock() + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) + if err == nil { + syscall.CloseOnExec(fd) + } + syscall.ForkLock.RUnlock() + + if err != nil { + syscall.Close(fd) + return nil, err + } + + var lip [4]byte + copy(lip[:], net.ParseIP(ip).To4()) + + if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, 0x0F, 1); err != nil { + return nil, err + } + + if err = syscall.Bind(fd, &syscall.SockaddrInet4{Port: port}); err != nil { + return nil, err + } + + // SO_REUSEADDR does not load balance so we use PORT + if multi { + if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + return nil, err + } + } + + //TODO: this may be useful for forcing threads into specific cores + //syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, unix.SO_INCOMING_CPU, x) + //v, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, unix.SO_INCOMING_CPU) + //l.Println(v, err) + + return &udpConn{sysFd: fd}, err +} + +func (u *udpConn) SetRecvBuffer(n int) error { + return syscall.SetsockoptInt(u.sysFd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, n) +} + +func (u *udpConn) SetSendBuffer(n int) error { + return syscall.SetsockoptInt(u.sysFd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, n) +} + +func (u *udpConn) GetRecvBuffer() (int, error) { + return syscall.GetsockoptInt(int(u.sysFd), syscall.SOL_SOCKET, syscall.SO_RCVBUF) +} + +func (u *udpConn) GetSendBuffer() (int, error) { + return syscall.GetsockoptInt(int(u.sysFd), syscall.SOL_SOCKET, syscall.SO_SNDBUF) +} + +func (u *udpConn) LocalAddr() (*udpAddr, error) { + var rsa rawSockaddrAny + var rLen = syscall.SizeofSockaddrAny + + _, _, err := syscall.Syscall( + syscall.SYS_GETSOCKNAME, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&rsa)), + uintptr(unsafe.Pointer(&rLen)), + ) + + if err != 0 { + return nil, err + } + + addr := &udpAddr{} + if rsa.Addr.Family == syscall.AF_INET { + addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1]) + addr.IP = uint32(rsa.Addr.Data[2])<<24 + uint32(rsa.Addr.Data[3])<<16 + uint32(rsa.Addr.Data[4])<<8 + uint32(rsa.Addr.Data[5]) + } else { + addr.Port = 0 + addr.IP = 0 + } + return addr, nil +} + +func (u *udpConn) ListenOut(f *Interface) { + plaintext := make([]byte, mtu) + header := &Header{} + fwPacket := &FirewallPacket{} + udpAddr := &udpAddr{} + nb := make([]byte, 12, 12) + + //TODO: should we track this? + //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) + msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize) + + for { + n, err := u.ReadMulti(msgs) + if err != nil { + l.WithError(err).Error("Failed to read packets") + continue + } + + //metric.Update(int64(n)) + for i := 0; i < n; i++ { + udpAddr.IP = binary.BigEndian.Uint32(names[i][4:8]) + udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) + + f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, nb) + } + } +} + +func (u *udpConn) Read(addr *udpAddr, b []byte) ([]byte, error) { + var rsa rawSockaddrAny + var rLen = syscall.SizeofSockaddrAny + + for { + n, _, err := syscall.Syscall6( + syscall.SYS_RECVFROM, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&b[0])), + uintptr(len(b)), + uintptr(0), + uintptr(unsafe.Pointer(&rsa)), + uintptr(unsafe.Pointer(&rLen)), + ) + + if err != 0 { + return nil, &net.OpError{Op: "read", Err: err} + } + + if rsa.Addr.Family == syscall.AF_INET { + addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1]) + addr.IP = uint32(rsa.Addr.Data[2])<<24 + uint32(rsa.Addr.Data[3])<<16 + uint32(rsa.Addr.Data[4])<<8 + uint32(rsa.Addr.Data[5]) + } else { + addr.Port = 0 + addr.IP = 0 + } + + return b[:n], nil + } +} + +func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) { + for { + n, _, err := syscall.Syscall6( + syscall.SYS_RECVMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(len(msgs)), + unix.MSG_WAITFORONE, + 0, + 0, + ) + + if err != 0 { + return 0, &net.OpError{Op: "recvmmsg", Err: err} + } + + return int(n), nil + } +} + +func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error { + var rsa syscall.RawSockaddrInet4 + + //TODO: sometimes addr is nil! + rsa.Family = syscall.AF_INET + p := (*[2]byte)(unsafe.Pointer(&rsa.Port)) + p[0] = byte(addr.Port >> 8) + p[1] = byte(addr.Port) + + rsa.Addr[0] = byte(addr.IP & 0xff000000 >> 24) + rsa.Addr[1] = byte(addr.IP & 0x00ff0000 >> 16) + rsa.Addr[2] = byte(addr.IP & 0x0000ff00 >> 8) + rsa.Addr[3] = byte(addr.IP & 0x000000ff) + + for { + _, _, err := syscall.Syscall6( + syscall.SYS_SENDTO, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&b[0])), + uintptr(len(b)), + uintptr(0), + uintptr(unsafe.Pointer(&rsa)), + uintptr(syscall.SizeofSockaddrInet4), + ) + + if err != 0 { + return &net.OpError{Op: "sendto", Err: err} + } + + //TODO: handle incomplete writes + + return nil + } +} + +func (u *udpConn) reloadConfig(c *Config) { + b := c.GetInt("listen.read_buffer", 0) + if b > 0 { + err := u.SetRecvBuffer(b) + if err == nil { + s, err := u.GetRecvBuffer() + if err == nil { + l.WithField("size", s).Info("listen.read_buffer was set") + } else { + l.WithError(err).Warn("Failed to get listen.read_buffer") + } + } else { + l.WithError(err).Error("Failed to set listen.read_buffer") + } + } + + b = c.GetInt("listen.write_buffer", 0) + if b > 0 { + err := u.SetSendBuffer(b) + if err == nil { + s, err := u.GetSendBuffer() + if err == nil { + l.WithField("size", s).Info("listen.write_buffer was set") + } else { + l.WithError(err).Warn("Failed to get listen.write_buffer") + } + } else { + l.WithError(err).Error("Failed to set listen.write_buffer") + } + } +} + +func (ua *udpAddr) Equals(t *udpAddr) bool { + if t == nil || ua == nil { + return t == nil && ua == nil + } + return ua.IP == t.IP && ua.Port == t.Port +} + +func (ua *udpAddr) Copy() *udpAddr { + return &udpAddr{ + Port: ua.Port, + IP: ua.IP, + } +} + +func (ua *udpAddr) String() string { + return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port) +} + +func (ua *udpAddr) MarshalJSON() ([]byte, error) { + return json.Marshal(m{"ip": int2ip(ua.IP), "port": ua.Port}) +} + +func udp2ip(addr *udpAddr) net.IP { + return int2ip(addr.IP) +} + +func udp2ipInt(addr *udpAddr) uint32 { + return addr.IP +} + +func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool { + return !addr.Equals(newaddr) +} diff --git a/udp_linux_amd64.go b/udp_linux_amd64.go new file mode 100644 index 0000000..a61e0c8 --- /dev/null +++ b/udp_linux_amd64.go @@ -0,0 +1,50 @@ +package nebula + +import "unsafe" + +type iovec struct { + Base *byte + Len uint64 +} + +type msghdr struct { + Name *byte + Namelen uint32 + Pad0 [4]byte + Iov *iovec + Iovlen uint64 + Control *byte + Controllen uint64 + Flags int32 + Pad1 [4]byte +} + +type rawMessage struct { + Hdr msghdr + Len uint32 + Pad0 [4]byte +} + +func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { + msgs := make([]rawMessage, n) + buffers := make([][]byte, n) + names := make([][]byte, n) + + for i := range msgs { + buffers[i] = make([]byte, mtu) + names[i] = make([]byte, 0x1c) //TODO = sizeofSockaddrInet6 + + //TODO: this is still silly, no need for an array + vs := []iovec{ + {Base: (*byte)(unsafe.Pointer(&buffers[i][0])), Len: uint64(len(buffers[i]))}, + } + + msgs[i].Hdr.Iov = &vs[0] + msgs[i].Hdr.Iovlen = uint64(len(vs)) + + msgs[i].Hdr.Name = (*byte)(unsafe.Pointer(&names[i][0])) + msgs[i].Hdr.Namelen = uint32(len(names[i])) + } + + return msgs, buffers, names +} diff --git a/udp_linux_arm.go b/udp_linux_arm.go new file mode 100644 index 0000000..2641b2c --- /dev/null +++ b/udp_linux_arm.go @@ -0,0 +1,47 @@ +package nebula + +import "unsafe" + +type iovec struct { + Base *byte + Len uint32 +} + +type msghdr struct { + Name *byte + Namelen uint32 + Iov *iovec + Iovlen uint32 + Control *byte + Controllen uint32 + Flags int32 +} + +type rawMessage struct { + Hdr msghdr + Len uint32 +} + +func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { + msgs := make([]rawMessage, n) + buffers := make([][]byte, n) + names := make([][]byte, n) + + for i := range msgs { + buffers[i] = make([]byte, mtu) + names[i] = make([]byte, 0x1c) //TODO = sizeofSockaddrInet6 + + //TODO: this is still silly, no need for an array + vs := []iovec{ + {Base: (*byte)(unsafe.Pointer(&buffers[i][0])), Len: uint32(len(buffers[i]))}, + } + + msgs[i].Hdr.Iov = &vs[0] + msgs[i].Hdr.Iovlen = uint32(len(vs)) + + msgs[i].Hdr.Name = (*byte)(unsafe.Pointer(&names[i][0])) + msgs[i].Hdr.Namelen = uint32(len(names[i])) + } + + return msgs, buffers, names +} diff --git a/udp_linux_arm64.go b/udp_linux_arm64.go new file mode 100644 index 0000000..a61e0c8 --- /dev/null +++ b/udp_linux_arm64.go @@ -0,0 +1,50 @@ +package nebula + +import "unsafe" + +type iovec struct { + Base *byte + Len uint64 +} + +type msghdr struct { + Name *byte + Namelen uint32 + Pad0 [4]byte + Iov *iovec + Iovlen uint64 + Control *byte + Controllen uint64 + Flags int32 + Pad1 [4]byte +} + +type rawMessage struct { + Hdr msghdr + Len uint32 + Pad0 [4]byte +} + +func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { + msgs := make([]rawMessage, n) + buffers := make([][]byte, n) + names := make([][]byte, n) + + for i := range msgs { + buffers[i] = make([]byte, mtu) + names[i] = make([]byte, 0x1c) //TODO = sizeofSockaddrInet6 + + //TODO: this is still silly, no need for an array + vs := []iovec{ + {Base: (*byte)(unsafe.Pointer(&buffers[i][0])), Len: uint64(len(buffers[i]))}, + } + + msgs[i].Hdr.Iov = &vs[0] + msgs[i].Hdr.Iovlen = uint64(len(vs)) + + msgs[i].Hdr.Name = (*byte)(unsafe.Pointer(&names[i][0])) + msgs[i].Hdr.Namelen = uint32(len(names[i])) + } + + return msgs, buffers, names +} diff --git a/udp_windows.go b/udp_windows.go new file mode 100644 index 0000000..6376503 --- /dev/null +++ b/udp_windows.go @@ -0,0 +1,22 @@ +package nebula + +// Windows support is primarily implemented in udp_generic, besides NewListenConfig + +import ( + "fmt" + "net" + "syscall" +) + +func NewListenConfig(multi bool) net.ListenConfig { + return net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + if multi { + // There is no way to support multiple listeners safely on Windows: + // https://docs.microsoft.com/en-us/windows/desktop/winsock/using-so-reuseaddr-and-so-exclusiveaddruse + return fmt.Errorf("multiple udp listeners not supported on windows") + } + return nil + }, + } +}