Public Release

This commit is contained in:
Slack Security Team 2019-11-19 17:00:20 +00:00
commit f22b4b584d
103 changed files with 14825 additions and 0 deletions

10
.gitignore vendored Normal file
View File

@ -0,0 +1,10 @@
/nebula
/nebula-cert
/nebula-arm
/nebula-arm6
/nebula-darwin
/nebula.exe
/cert/*.crt
/cert/*.key
/coverage.out
/cpu.pprof

9
AUTHORS Normal file
View File

@ -0,0 +1,9 @@
# This is the official list of Nebula authors for copyright purposes.
# Names should be added to this file as:
# Name or Organization <email address>
# The email address is not required for organizations.
Slack Technologies, Inc.
Nate Brown <nbrown.us@gmail.com>
Ryan Huber <rhuber@gmail.com>

24
LICENSE Normal file
View File

@ -0,0 +1,24 @@
MIT License
Copyright (c) 2018-2019 Slack Technologies, Inc.
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

77
Makefile Normal file
View File

@ -0,0 +1,77 @@
BUILD_NUMBER ?= dev+$(shell date -u '+%Y%m%d%H%M%S')
GO111MODULE = on
export GO111MODULE
all:
make bin
make bin-arm
make bin-arm6
make bin-arm64
make bin-darwin
make bin-windows
bin:
go build -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula ./cmd/nebula
go build -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula-cert ./cmd/nebula-cert
install:
go install -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula
go install -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert
bin-arm:
GOARCH=arm GOOS=linux go build -o nebula-arm -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula
bin-arm6:
GOARCH=arm GOARM=6 GOOS=linux go build -o nebula-arm6 -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula
bin-arm64:
GOARCH=arm64 GOOS=linux go build -o nebula-arm64 -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula
bin-vagrant:
GOARCH=amd64 GOOS=linux go build -o nebula -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula
GOARCH=amd64 GOOS=linux go build -ldflags "-X main.Build=$(BUILD_NUMBER)" -o ./nebula-cert ./cmd/nebula-cert
bin-darwin:
GOARCH=amd64 GOOS=darwin go build -o nebula-darwin -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula
bin-windows:
GOARCH=amd64 GOOS=windows go build -o nebula.exe -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula
bin-linux:
GOARCH=amd64 GOOS=linux go build -o ./nebula -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula
GOARCH=amd64 GOOS=linux go build -o ./nebula-cert -ldflags "-X main.Build=$(BUILD_NUMBER)" ./cmd/nebula-cert
vet:
go vet -v ./...
test:
go test -v ./...
test-cov-html:
go test -coverprofile=coverage.out
go tool cover -html=coverage.out
bench:
go test -bench=.
bench-cpu:
go test -bench=. -benchtime=5s -cpuprofile=cpu.pprof
go tool pprof go-audit.test cpu.pprof
bench-cpu-long:
go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
go tool pprof go-audit.test cpu.pprof
proto: nebula.pb.go cert/cert.pb.go
nebula.pb.go: nebula.proto .FORCE
go build github.com/golang/protobuf/protoc-gen-go
PATH="$(PWD):$(PATH)" protoc --go_out=. $<
rm protoc-gen-go
cert/cert.pb.go: cert/cert.proto .FORCE
$(MAKE) -C cert cert.pb.go
.FORCE:
.PHONY: test test-cov-html bench bench-cpu bench-cpu-long bin proto
.DEFAULT_GOAL := bin

91
README.md Normal file
View File

@ -0,0 +1,91 @@
## What is Nebula?
Nebula is a scalable overlay networking tool with a focus on performance, simplicity and security.
It lets you seamlessly connect computers anywhere in the world. Nebula is portable, and runs on Linux, OSX, and Windows.
(Also: keep this quiet, but we have an early prototype running on iOS).
It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers.
Nebula incorporates a number of existing concepts like encryption, security groups, certificates,
and tunneling, and each of those individual pieces existed before Nebula in various forms.
What makes Nebula different to existing offerings is that it brings all of these ideas together,
resulting in a sum that is greater than its individual parts.
You can read more about Nebula [here](https://medium.com/p/884110a5579).
## Technical Overview
Nebula is a mutually authenticated peer-to-peer software defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
Nebula uses certificates to assert a node's IP address, name, and membership within user-defined groups.
Nebula's user-defined groups allow for provider agnostic traffic filtering between nodes.
Discovery nodes allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs.
Users can move data between nodes in any number of cloud service providers, datacenters, and endpoints, without needing to maintain a particular addressing scheme.
Nebula uses elliptic curve Diffie-Hellman key exchange, and AES-256-GCM in its default configuration.
Nebula was created to provide a mechanism for groups hosts to communicate securely, even across the internet, while enabling expressive firewall definitions similar in style to cloud security groups.
## Getting started (quickly)
To set up a Nebula network, you'll need:
#### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) for your specific platform. Specifically you'll need `nebula-cert` and the specific nebula binary for each platform you use.
#### 2. (Optional, but you really should..) At least one discovery node with a routable IP address, which we call a lighthouse.
Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $5/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses.
Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet.
#### 3. A Nebula certificate authority, which will be the root of trust for a particular Nebula network.
```
./nebula-cert ca -name "Myorganization, Inc"
```
This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption.
#### 4. Nebula host keys and certificates generated from that certificate authority
This assumes you have three nodes, named lighthouse1, host1, host3. You can name the nodes any way you'd like, including FQDN. You'll also need to choose IP addresses and the associated subnet. In this example, we are creating a nebula network that will use 192.168.100.x/24 as its network range. This example also demonstrates nebula groups, which can later be used to define traffic rules in a nebula network.
```
./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24"
./nebula-cert sign -name "laptop" -ip "192.168.100.2/24" -groups "laptop,home,ssh"
./nebula-cert sign -name "server1" -ip "192.168.100.9/24" -groups "servers"
./nebula-cert sign -name "host3" -ip "192.168.100.9/24"
```
#### 5. Configuration files for each host
Download a copy of the nebula [example configuration](https://github.com/slackhq/nebula/blob/master/examples/config.yaml).
* On the lighthouse node, you'll need to ensure `am_lighthouse: true` is set.
* On the individual hosts, ensure the lighthouse is defined properly in the `static_host_map` section, and is added to the lighthouse `hosts` section.
#### 6. Copy nebula credentials, configuration, and binaries to each host
For each host, copy the nebula binary to the host, along with `config.yaml` from step 5, and the files `ca.crt`, `{host}.crt`, and `{host}.key` from step 2.
**DO NOT COPY `ca.key` TO INDIVIDUAL NODES.**
#### 7. Run nebula on each host
```
./nebula -config /path/to/config.yaml
```
## Building Nebula from source
Download go and clone this repo. Change to the nebula directory.
To build nebula for all platforms:
`make all`
To build nebula for a specific platform (ex, Windows):
`make bin-windows`
See the [Makefile](Makefile) for more details on build targets
## Credits
Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.

157
bits.go Normal file
View File

@ -0,0 +1,157 @@
package nebula
import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
)
type Bits struct {
length uint64
current uint64
bits []bool
firstSeen bool
lostCounter metrics.Counter
dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter
}
func NewBits(bits uint64) *Bits {
return &Bits{
length: bits,
bits: make([]bool, bits, bits),
current: 0,
lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil),
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
}
}
func (b *Bits) Check(i uint64) bool {
// If i is the next number, return true.
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
return true
}
// If i is within the window, check if it's been set already. The first window will fail this check
if i > b.current-b.length {
return !b.bits[i%b.length]
}
// If i is within the first window
if i < b.length {
return !b.bits[i%b.length]
}
// Not within the window
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
return false
}
func (b *Bits) Update(i uint64) bool {
// If i is the next number, return true and update current.
if i == b.current+1 {
// Report missed packets, we can only understand what was missed after the first window has been gone through
if i > b.length && b.bits[i%b.length] == false {
b.lostCounter.Inc(1)
}
b.bits[i%b.length] = true
b.current = i
return true
}
// If i packet is greater than current but less than the maximum length of our bitmap,
// flip everything in between to false and move ahead.
if i > b.current && i < b.current+b.length {
// In between current and i need to be zero'd to allow those packets to come in later
for n := b.current + 1; n < i; n++ {
b.bits[n%b.length] = false
}
b.bits[i%b.length] = true
b.current = i
//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
return true
}
// If i is greater than the delta between current and the total length of our bitmap,
// just flip everything in the map and move ahead.
if i >= b.current+b.length {
// The current window loss will be accounted for later, only record the jump as loss up until then
lost := maxInt64(0, int64(i-b.current-b.length))
//TODO: explain this
if b.current == 0 {
lost++
}
for n := range b.bits {
// Don't want to count the first window as a loss
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
//if b.bits[n] == false {
// lost++
//}
b.bits[n] = false
}
b.lostCounter.Inc(lost)
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
Debug("Receive window")
}
b.bits[i%b.length] = true
b.current = i
return true
}
// Allow for the 0 packet to come in within the first window
if i == 0 && b.firstSeen == false && b.current < b.length {
b.firstSeen = true
b.bits[i%b.length] = true
return true
}
// If i is within the window of current minus length (the total pat window size),
// allow it and flip to true but to NOT change current. We also have to account for the first window
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
if b.current == i {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
Debug("Receive window")
}
b.dupeCounter.Inc(1)
return false
}
if b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
Debug("Receive window")
}
b.dupeCounter.Inc(1)
return false
}
b.bits[i%b.length] = true
return true
}
// In all other cases, fail and don't change current.
b.outOfWindowCounter.Inc(1)
if l.Level >= logrus.DebugLevel {
l.WithField("accepted", false).
WithField("currentCounter", b.current).
WithField("incomingCounter", i).
WithField("reason", "nonsense").
Debug("Receive window")
}
return false
}
func maxInt64(a, b int64) int64 {
if a > b {
return a
}
return b
}

223
bits_test.go Normal file
View File

@ -0,0 +1,223 @@
package nebula
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBits(t *testing.T) {
b := NewBits(10)
// make sure it is the right size
assert.Len(t, b.bits, 10)
// This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(1))
u := b.Update(1)
assert.True(t, u)
assert.EqualValues(t, 1, b.current)
g := []bool{false, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two
assert.True(t, b.Check(2))
u = b.Update(2)
assert.True(t, u)
assert.EqualValues(t, 2, b.current)
g = []bool{false, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two again - it will fail
assert.False(t, b.Check(2))
u = b.Update(2)
assert.False(t, u)
assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(15))
u = b.Update(15)
assert.True(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(14))
u = b.Update(14)
assert.True(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 5, which is not allowed because it is not in the window
assert.False(t, b.Check(5))
u = b.Update(5)
assert.False(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// make sure we handle wrapping around once to the current position
b = NewBits(10)
assert.True(t, b.Update(1))
assert.True(t, b.Update(11))
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)
// Walk through a few windows in order
b = NewBits(10)
for i := uint64(0); i <= 100; i++ {
assert.True(t, b.Check(i), "Error while checking %v", i)
assert.True(t, b.Update(i), "Error while updating %v", i)
}
}
func TestBitsDupeCounter(t *testing.T) {
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(1))
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.False(t, b.Update(1))
assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.True(t, b.Update(2))
assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.True(t, b.Update(3))
assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.False(t, b.Update(1))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.Equal(t, int64(2), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func TestBitsOutOfWindowCounter(t *testing.T) {
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(20))
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
assert.True(t, b.Update(21))
assert.True(t, b.Update(22))
assert.True(t, b.Update(23))
assert.True(t, b.Update(24))
assert.True(t, b.Update(25))
assert.True(t, b.Update(26))
assert.True(t, b.Update(27))
assert.True(t, b.Update(28))
assert.True(t, b.Update(29))
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
assert.False(t, b.Update(0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
//tODO: make sure lostcounter doesn't increase in orderly increment
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
}
func TestBitsLostCounter(t *testing.T) {
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
//assert.True(t, b.Update(0))
assert.True(t, b.Update(0))
assert.True(t, b.Update(20))
assert.True(t, b.Update(21))
assert.True(t, b.Update(22))
assert.True(t, b.Update(23))
assert.True(t, b.Update(24))
assert.True(t, b.Update(25))
assert.True(t, b.Update(26))
assert.True(t, b.Update(27))
assert.True(t, b.Update(28))
assert.True(t, b.Update(29))
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
b = NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(0))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(9))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets
assert.True(t, b.Update(10))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 11 will set 1 index, 1 was missed, we should see 1 packet lost
assert.True(t, b.Update(11))
assert.Equal(t, int64(1), b.lostCounter.Count())
// Now let's fill in the window, should end up with 8 lost packets
assert.True(t, b.Update(12))
assert.True(t, b.Update(13))
assert.True(t, b.Update(14))
assert.True(t, b.Update(15))
assert.True(t, b.Update(16))
assert.True(t, b.Update(17))
assert.True(t, b.Update(18))
assert.True(t, b.Update(19))
assert.Equal(t, int64(8), b.lostCounter.Count())
// Jump ahead by a window size
assert.True(t, b.Update(29))
assert.Equal(t, int64(8), b.lostCounter.Count())
// Now lets walk ahead normally through the window, the missed packets should fill in
assert.True(t, b.Update(30))
assert.True(t, b.Update(31))
assert.True(t, b.Update(32))
assert.True(t, b.Update(33))
assert.True(t, b.Update(34))
assert.True(t, b.Update(35))
assert.True(t, b.Update(36))
assert.True(t, b.Update(37))
assert.True(t, b.Update(38))
// 39 packets tracked, 22 seen, 17 lost
assert.Equal(t, int64(17), b.lostCounter.Count())
// Jump ahead by 2 windows, should have recording 1 full window missing
assert.True(t, b.Update(58))
assert.Equal(t, int64(27), b.lostCounter.Count())
// Now lets walk ahead normally through the window, the missed packets should fill in from this window
assert.True(t, b.Update(59))
assert.True(t, b.Update(60))
assert.True(t, b.Update(61))
assert.True(t, b.Update(62))
assert.True(t, b.Update(63))
assert.True(t, b.Update(64))
assert.True(t, b.Update(65))
assert.True(t, b.Update(66))
assert.True(t, b.Update(67))
// 68 packets tracked, 32 seen, 36 missed
assert.Equal(t, int64(36), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func BenchmarkBits(b *testing.B) {
z := NewBits(10)
for n := 0; n < b.N; n++ {
for i, _ := range z.bits {
z.bits[i] = true
}
for i, _ := range z.bits {
z.bits[i] = false
}
}
}

159
cert.go Normal file
View File

@ -0,0 +1,159 @@
package nebula
import (
"errors"
"fmt"
"io/ioutil"
"strings"
"time"
"github.com/slackhq/nebula/cert"
)
var trustedCAs *cert.NebulaCAPool
type CertState struct {
certificate *cert.NebulaCertificate
rawCertificate []byte
rawCertificateNoKey []byte
publicKey []byte
privateKey []byte
}
func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
// Marshal the certificate to ensure it is valid
rawCertificate, err := certificate.Marshal()
if err != nil {
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
}
publicKey := certificate.Details.PublicKey
cs := &CertState{
rawCertificate: rawCertificate,
certificate: certificate, // PublicKey has been set to nil above
privateKey: privateKey,
publicKey: publicKey,
}
cs.certificate.Details.PublicKey = nil
rawCertNoKey, err := cs.certificate.Marshal()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
}
cs.rawCertificateNoKey = rawCertNoKey
// put public key back
cs.certificate.Details.PublicKey = cs.publicKey
return cs, nil
}
func NewCertStateFromConfig(c *Config) (*CertState, error) {
var pemPrivateKey []byte
var err error
privPathOrPEM := c.GetString("pki.key", "")
if privPathOrPEM == "" {
// Support backwards compat with the old x509
//TODO: remove after this is rolled out everywhere - NB 2018/02/23
privPathOrPEM = c.GetString("x509.key", "")
}
if privPathOrPEM == "" {
return nil, errors.New("no pki.key path or PEM data provided")
}
if strings.Contains(privPathOrPEM, "-----BEGIN") {
pemPrivateKey = []byte(privPathOrPEM)
privPathOrPEM = "<inline>"
} else {
pemPrivateKey, err = ioutil.ReadFile(privPathOrPEM)
if err != nil {
return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
}
rawKey, _, err := cert.UnmarshalX25519PrivateKey(pemPrivateKey)
if err != nil {
return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
var rawCert []byte
pubPathOrPEM := c.GetString("pki.cert", "")
if pubPathOrPEM == "" {
// Support backwards compat with the old x509
//TODO: remove after this is rolled out everywhere - NB 2018/02/23
pubPathOrPEM = c.GetString("x509.cert", "")
}
if pubPathOrPEM == "" {
return nil, errors.New("no pki.cert path or PEM data provided")
}
if strings.Contains(pubPathOrPEM, "-----BEGIN") {
rawCert = []byte(pubPathOrPEM)
pubPathOrPEM = "<inline>"
} else {
rawCert, err = ioutil.ReadFile(pubPathOrPEM)
if err != nil {
return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
}
}
nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
if err != nil {
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
}
if nebulaCert.Expired(time.Now()) {
return nil, fmt.Errorf("nebula certificate for this host is expired")
}
if len(nebulaCert.Details.Ips) == 0 {
return nil, fmt.Errorf("no IPs encoded in certificate")
}
if err = nebulaCert.VerifyPrivateKey(rawKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
}
return NewCertState(nebulaCert, rawKey)
}
func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) {
var rawCA []byte
var err error
caPathOrPEM := c.GetString("pki.ca", "")
if caPathOrPEM == "" {
// Support backwards compat with the old x509
//TODO: remove after this is rolled out everywhere - NB 2018/02/23
caPathOrPEM = c.GetString("x509.ca", "")
}
if caPathOrPEM == "" {
return nil, errors.New("no pki.ca path or PEM data provided")
}
if strings.Contains(caPathOrPEM, "-----BEGIN") {
rawCA = []byte(caPathOrPEM)
caPathOrPEM = "<inline>"
} else {
rawCA, err = ioutil.ReadFile(caPathOrPEM)
if err != nil {
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
}
}
CAs, err := cert.NewCAPoolFromBytes(rawCA)
if err != nil {
return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
}
// pki.blacklist entered the scene at about the same time we aliased x509 to pki, not supporting backwards compat
for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
l.WithField("fingerprint", fp).Infof("Blacklisting cert")
CAs.BlacklistFingerprint(fp)
}
return CAs, nil
}

9
cert/Makefile Normal file
View File

@ -0,0 +1,9 @@
GO111MODULE = on
export GO111MODULE
cert.pb.go: cert.proto .FORCE
go build github.com/golang/protobuf/protoc-gen-go
PATH="$(PWD):$(PATH)" protoc --go_out=. $<
rm protoc-gen-go
.FORCE:

15
cert/README.md Normal file
View File

@ -0,0 +1,15 @@
## `cert`
This is a library for interacting with `nebula` style certificates and authorities.
A `protobuf` definition of the certificate format is also included
### Compiling the protobuf definition
Make sure you have `protoc` installed.
To compile for `go` with the same version of protobuf specified in go.mod:
```bash
make
```

120
cert/ca.go Normal file
View File

@ -0,0 +1,120 @@
package cert
import (
"fmt"
"strings"
"time"
)
type NebulaCAPool struct {
CAs map[string]*NebulaCertificate
certBlacklist map[string]struct{}
}
// NewCAPool creates a CAPool
func NewCAPool() *NebulaCAPool {
ca := NebulaCAPool{
CAs: make(map[string]*NebulaCertificate),
certBlacklist: make(map[string]struct{}),
}
return &ca
}
func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) {
pool := NewCAPool()
var err error
for {
caPEMs, err = pool.AddCACertificate(caPEMs)
if err != nil {
return nil, err
}
if caPEMs == nil || len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
break
}
}
return pool, nil
}
// AddCACertificate verifies a Nebula CA certificate and adds it to the pool
// Only the first pem encoded object will be consumed, any remaining bytes are returned.
// Parsed certificates will be verified and must be a CA
func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
c, pemBytes, err := UnmarshalNebulaCertificateFromPEM(pemBytes)
if err != nil {
return pemBytes, err
}
if !c.Details.IsCA {
return pemBytes, fmt.Errorf("provided certificate was not a CA; %s", c.Details.Name)
}
if !c.CheckSignature(c.Details.PublicKey) {
return pemBytes, fmt.Errorf("provided certificate was not self signed; %s", c.Details.Name)
}
if c.Expired(time.Now()) {
return pemBytes, fmt.Errorf("provided CA certificate is expired; %s", c.Details.Name)
}
sum, err := c.Sha256Sum()
if err != nil {
return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name)
}
ncp.CAs[sum] = c
return pemBytes, nil
}
// BlacklistFingerprint adds a cert fingerprint to the blacklist
func (ncp *NebulaCAPool) BlacklistFingerprint(f string) {
ncp.certBlacklist[f] = struct{}{}
}
// ResetCertBlacklist removes all previously blacklisted cert fingerprints
func (ncp *NebulaCAPool) ResetCertBlacklist() {
ncp.certBlacklist = make(map[string]struct{})
}
// IsBlacklisted returns true if the fingerprint fails to generate or has been explicitly blacklisted
func (ncp *NebulaCAPool) IsBlacklisted(c *NebulaCertificate) bool {
h, err := c.Sha256Sum()
if err != nil {
return true
}
if _, ok := ncp.certBlacklist[h]; ok {
return true
}
return false
}
// GetCAForCert attempts to return the signing certificate for the provided certificate.
// No signature validation is performed
func (ncp *NebulaCAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) {
if c.Details.Issuer == "" {
return nil, fmt.Errorf("no issuer in certificate")
}
signer, ok := ncp.CAs[c.Details.Issuer]
if ok {
return signer, nil
}
return nil, fmt.Errorf("could not find ca for the certificate")
}
// GetFingerprints returns an array of trusted CA fingerprints
func (ncp *NebulaCAPool) GetFingerprints() []string {
fp := make([]string, len(ncp.CAs))
i := 0
for k := range ncp.CAs {
fp[i] = k
i++
}
return fp
}

445
cert/cert.go Normal file
View File

@ -0,0 +1,445 @@
package cert
import (
"crypto"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/pem"
"fmt"
"net"
"time"
"bytes"
"encoding/json"
"github.com/golang/protobuf/proto"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
)
const publicKeyLen = 32
const (
CertBanner = "NEBULA CERTIFICATE"
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
)
type NebulaCertificate struct {
Details NebulaCertificateDetails
Signature []byte
}
type NebulaCertificateDetails struct {
Name string
Ips []*net.IPNet
Subnets []*net.IPNet
Groups []string
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
IsCA bool
Issuer string
// Map of groups for faster lookup
InvertedGroups map[string]struct{}
}
type m map[string]interface{}
// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert
func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) {
if len(b) == 0 {
return nil, fmt.Errorf("nil byte array")
}
var rc RawNebulaCertificate
err := proto.Unmarshal(b, &rc)
if err != nil {
return nil, err
}
if len(rc.Details.Ips)%2 != 0 {
return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
}
if len(rc.Details.Subnets)%2 != 0 {
return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
}
nc := NebulaCertificate{
Details: NebulaCertificateDetails{
Name: rc.Details.Name,
Groups: make([]string, len(rc.Details.Groups)),
Ips: make([]*net.IPNet, len(rc.Details.Ips)/2),
Subnets: make([]*net.IPNet, len(rc.Details.Subnets)/2),
NotBefore: time.Unix(rc.Details.NotBefore, 0),
NotAfter: time.Unix(rc.Details.NotAfter, 0),
PublicKey: make([]byte, len(rc.Details.PublicKey)),
IsCA: rc.Details.IsCA,
InvertedGroups: make(map[string]struct{}),
},
Signature: make([]byte, len(rc.Signature)),
}
copy(nc.Signature, rc.Signature)
copy(nc.Details.Groups, rc.Details.Groups)
nc.Details.Issuer = hex.EncodeToString(rc.Details.Issuer)
if len(rc.Details.PublicKey) < publicKeyLen {
return nil, fmt.Errorf("Public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey))
}
copy(nc.Details.PublicKey, rc.Details.PublicKey)
for i, rawIp := range rc.Details.Ips {
if i%2 == 0 {
nc.Details.Ips[i/2] = &net.IPNet{IP: int2ip(rawIp)}
} else {
nc.Details.Ips[i/2].Mask = net.IPMask(int2ip(rawIp))
}
}
for i, rawIp := range rc.Details.Subnets {
if i%2 == 0 {
nc.Details.Subnets[i/2] = &net.IPNet{IP: int2ip(rawIp)}
} else {
nc.Details.Subnets[i/2].Mask = net.IPMask(int2ip(rawIp))
}
}
for _, g := range rc.Details.Groups {
nc.Details.InvertedGroups[g] = struct{}{}
}
return &nc, nil
}
// UnmarshalNebulaCertificateFromPEM will unmarshal the first pem block in a byte array, returning any non consumed data
// or an error on failure
func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
p, r := pem.Decode(b)
if p == nil {
return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
}
nc, err := UnmarshalNebulaCertificate(p.Bytes)
return nc, r, err
}
// MarshalX25519PrivateKey is a simple helper to PEM encode an X25519 private key
func MarshalX25519PrivateKey(b []byte) []byte {
return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
}
// MarshalEd25519PrivateKey is a simple helper to PEM encode an Ed25519 private key
func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte {
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key})
}
// UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b
// or an error on failure
func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) {
k, r := pem.Decode(b)
if k == nil {
return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
}
if k.Type != X25519PrivateKeyBanner {
return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 private key banner")
}
if len(k.Bytes) != publicKeyLen {
return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 private key")
}
return k.Bytes, r, nil
}
// UnmarshalEd25519PrivateKey will try to pem decode an Ed25519 private key, returning any other bytes b
// or an error on failure
func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) {
k, r := pem.Decode(b)
if k == nil {
return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
}
if k.Type != Ed25519PrivateKeyBanner {
return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner")
}
if len(k.Bytes) != ed25519.PrivateKeySize {
return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
}
return k.Bytes, r, nil
}
// MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key
func MarshalX25519PublicKey(b []byte) []byte {
return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
}
// MarshalEd25519PublicKey is a simple helper to PEM encode an Ed25519 public key
func MarshalEd25519PublicKey(key ed25519.PublicKey) []byte {
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: key})
}
// UnmarshalX25519PublicKey will try to pem decode an X25519 public key, returning any other bytes b
// or an error on failure
func UnmarshalX25519PublicKey(b []byte) ([]byte, []byte, error) {
k, r := pem.Decode(b)
if k == nil {
return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
}
if k.Type != X25519PublicKeyBanner {
return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 public key banner")
}
if len(k.Bytes) != publicKeyLen {
return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 public key")
}
return k.Bytes, r, nil
}
// UnmarshalEd25519PublicKey will try to pem decode an Ed25519 public key, returning any other bytes b
// or an error on failure
func UnmarshalEd25519PublicKey(b []byte) (ed25519.PublicKey, []byte, error) {
k, r := pem.Decode(b)
if k == nil {
return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
}
if k.Type != Ed25519PublicKeyBanner {
return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 public key banner")
}
if len(k.Bytes) != ed25519.PublicKeySize {
return nil, r, fmt.Errorf("key was not 32 bytes, is invalid ed25519 public key")
}
return k.Bytes, r, nil
}
// Sign signs a nebula cert with the provided private key
func (nc *NebulaCertificate) Sign(key ed25519.PrivateKey) error {
b, err := proto.Marshal(nc.getRawDetails())
if err != nil {
return err
}
sig, err := key.Sign(rand.Reader, b, crypto.Hash(0))
if err != nil {
return err
}
nc.Signature = sig
return nil
}
// CheckSignature verifies the signature against the provided public key
func (nc *NebulaCertificate) CheckSignature(key ed25519.PublicKey) bool {
b, err := proto.Marshal(nc.getRawDetails())
if err != nil {
return false
}
return ed25519.Verify(key, b, nc.Signature)
}
// Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false
func (nc *NebulaCertificate) Expired(t time.Time) bool {
return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t)
}
// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blacklist, etc)
func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
if ncp.IsBlacklisted(nc) {
return false, fmt.Errorf("certificate has been blacklisted")
}
signer, err := ncp.GetCAForCert(nc)
if err != nil {
return false, err
}
if signer.Expired(t) {
return false, fmt.Errorf("root certificate is expired")
}
if nc.Expired(t) {
return false, fmt.Errorf("certificate is expired")
}
if !nc.CheckSignature(signer.Details.PublicKey) {
return false, fmt.Errorf("certificate signature did not match")
}
// If the signer has a limited set of groups make sure the cert only contains a subset
if len(signer.Details.InvertedGroups) > 0 {
for _, g := range nc.Details.Groups {
if _, ok := signer.Details.InvertedGroups[g]; !ok {
return false, fmt.Errorf("certificate contained a group not present on the signing ca; %s", g)
}
}
}
return true, nil
}
// VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match
func (nc *NebulaCertificate) VerifyPrivateKey(key []byte) error {
var dst, key32 [32]byte
copy(key32[:], key)
curve25519.ScalarBaseMult(&dst, &key32)
if !bytes.Equal(dst[:], nc.Details.PublicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
return nil
}
// String will return a pretty printed representation of a nebula cert
func (nc *NebulaCertificate) String() string {
if nc == nil {
return "NebulaCertificate {}\n"
}
s := "NebulaCertificate {\n"
s += "\tDetails {\n"
s += fmt.Sprintf("\t\tName: %v\n", nc.Details.Name)
if len(nc.Details.Ips) > 0 {
s += "\t\tIps: [\n"
for _, ip := range nc.Details.Ips {
s += fmt.Sprintf("\t\t\t%v\n", ip.String())
}
s += "\t\t]\n"
} else {
s += "\t\tIps: []\n"
}
if len(nc.Details.Subnets) > 0 {
s += "\t\tSubnets: [\n"
for _, ip := range nc.Details.Subnets {
s += fmt.Sprintf("\t\t\t%v\n", ip.String())
}
s += "\t\t]\n"
} else {
s += "\t\tSubnets: []\n"
}
if len(nc.Details.Groups) > 0 {
s += "\t\tGroups: [\n"
for _, g := range nc.Details.Groups {
s += fmt.Sprintf("\t\t\t\"%v\"\n", g)
}
s += "\t\t]\n"
} else {
s += "\t\tGroups: []\n"
}
s += fmt.Sprintf("\t\tNot before: %v\n", nc.Details.NotBefore)
s += fmt.Sprintf("\t\tNot After: %v\n", nc.Details.NotAfter)
s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA)
s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer)
s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey)
s += "\t}\n"
fp, err := nc.Sha256Sum()
if err == nil {
s += fmt.Sprintf("\tFingerprint: %s\n", fp)
}
s += fmt.Sprintf("\tSignature: %x\n", nc.Signature)
s += "}"
return s
}
// getRawDetails marshals the raw details into protobuf ready struct
func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails {
rd := &RawNebulaCertificateDetails{
Name: nc.Details.Name,
Groups: nc.Details.Groups,
NotBefore: nc.Details.NotBefore.Unix(),
NotAfter: nc.Details.NotAfter.Unix(),
PublicKey: make([]byte, len(nc.Details.PublicKey)),
IsCA: nc.Details.IsCA,
}
for _, ipNet := range nc.Details.Ips {
rd.Ips = append(rd.Ips, ip2int(ipNet.IP), ip2int(ipNet.Mask))
}
for _, ipNet := range nc.Details.Subnets {
rd.Subnets = append(rd.Subnets, ip2int(ipNet.IP), ip2int(ipNet.Mask))
}
copy(rd.PublicKey, nc.Details.PublicKey[:])
// I know, this is terrible
rd.Issuer, _ = hex.DecodeString(nc.Details.Issuer)
return rd
}
// Marshal will marshal a nebula cert into a protobuf byte array
func (nc *NebulaCertificate) Marshal() ([]byte, error) {
rc := RawNebulaCertificate{
Details: nc.getRawDetails(),
Signature: nc.Signature,
}
return proto.Marshal(&rc)
}
// MarshalToPEM will marshal a nebula cert into a protobuf byte array and pem encode the result
func (nc *NebulaCertificate) MarshalToPEM() ([]byte, error) {
b, err := nc.Marshal()
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: CertBanner, Bytes: b}), nil
}
// Sha256Sum calculates a sha-256 sum of the marshaled certificate
func (nc *NebulaCertificate) Sha256Sum() (string, error) {
b, err := nc.Marshal()
if err != nil {
return "", err
}
sum := sha256.Sum256(b)
return hex.EncodeToString(sum[:]), nil
}
func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
toString := func(ips []*net.IPNet) []string {
s := []string{}
for _, ip := range ips {
s = append(s, ip.String())
}
return s
}
fp, _ := nc.Sha256Sum()
jc := m{
"details": m{
"name": nc.Details.Name,
"ips": toString(nc.Details.Ips),
"subnets": toString(nc.Details.Subnets),
"groups": nc.Details.Groups,
"notBefore": nc.Details.NotBefore,
"notAfter": nc.Details.NotAfter,
"publicKey": fmt.Sprintf("%x", nc.Details.PublicKey),
"isCa": nc.Details.IsCA,
"issuer": nc.Details.Issuer,
},
"fingerprint": fp,
"signature": fmt.Sprintf("%x", nc.Signature),
}
return json.Marshal(jc)
}
func ip2int(ip []byte) uint32 {
if len(ip) == 16 {
return binary.BigEndian.Uint32(ip[12:16])
}
return binary.BigEndian.Uint32(ip)
}
func int2ip(nn uint32) net.IP {
ip := make(net.IP, net.IPv4len)
binary.BigEndian.PutUint32(ip, nn)
return ip
}

202
cert/cert.pb.go Normal file
View File

@ -0,0 +1,202 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: cert.proto
package cert
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type RawNebulaCertificate struct {
Details *RawNebulaCertificateDetails `protobuf:"bytes,1,opt,name=Details,json=details,proto3" json:"Details,omitempty"`
Signature []byte `protobuf:"bytes,2,opt,name=Signature,json=signature,proto3" json:"Signature,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *RawNebulaCertificate) Reset() { *m = RawNebulaCertificate{} }
func (m *RawNebulaCertificate) String() string { return proto.CompactTextString(m) }
func (*RawNebulaCertificate) ProtoMessage() {}
func (*RawNebulaCertificate) Descriptor() ([]byte, []int) {
return fileDescriptor_a142e29cbef9b1cf, []int{0}
}
func (m *RawNebulaCertificate) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_RawNebulaCertificate.Unmarshal(m, b)
}
func (m *RawNebulaCertificate) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_RawNebulaCertificate.Marshal(b, m, deterministic)
}
func (m *RawNebulaCertificate) XXX_Merge(src proto.Message) {
xxx_messageInfo_RawNebulaCertificate.Merge(m, src)
}
func (m *RawNebulaCertificate) XXX_Size() int {
return xxx_messageInfo_RawNebulaCertificate.Size(m)
}
func (m *RawNebulaCertificate) XXX_DiscardUnknown() {
xxx_messageInfo_RawNebulaCertificate.DiscardUnknown(m)
}
var xxx_messageInfo_RawNebulaCertificate proto.InternalMessageInfo
func (m *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails {
if m != nil {
return m.Details
}
return nil
}
func (m *RawNebulaCertificate) GetSignature() []byte {
if m != nil {
return m.Signature
}
return nil
}
type RawNebulaCertificateDetails struct {
Name string `protobuf:"bytes,1,opt,name=Name,json=name,proto3" json:"Name,omitempty"`
// Ips and Subnets are in big endian 32 bit pairs, 1st the ip, 2nd the mask
Ips []uint32 `protobuf:"varint,2,rep,packed,name=Ips,json=ips,proto3" json:"Ips,omitempty"`
Subnets []uint32 `protobuf:"varint,3,rep,packed,name=Subnets,json=subnets,proto3" json:"Subnets,omitempty"`
Groups []string `protobuf:"bytes,4,rep,name=Groups,json=groups,proto3" json:"Groups,omitempty"`
NotBefore int64 `protobuf:"varint,5,opt,name=NotBefore,json=notBefore,proto3" json:"NotBefore,omitempty"`
NotAfter int64 `protobuf:"varint,6,opt,name=NotAfter,json=notAfter,proto3" json:"NotAfter,omitempty"`
PublicKey []byte `protobuf:"bytes,7,opt,name=PublicKey,json=publicKey,proto3" json:"PublicKey,omitempty"`
IsCA bool `protobuf:"varint,8,opt,name=IsCA,json=isCA,proto3" json:"IsCA,omitempty"`
// sha-256 of the issuer certificate, if this field is blank the cert is self-signed
Issuer []byte `protobuf:"bytes,9,opt,name=Issuer,json=issuer,proto3" json:"Issuer,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *RawNebulaCertificateDetails) Reset() { *m = RawNebulaCertificateDetails{} }
func (m *RawNebulaCertificateDetails) String() string { return proto.CompactTextString(m) }
func (*RawNebulaCertificateDetails) ProtoMessage() {}
func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) {
return fileDescriptor_a142e29cbef9b1cf, []int{1}
}
func (m *RawNebulaCertificateDetails) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_RawNebulaCertificateDetails.Unmarshal(m, b)
}
func (m *RawNebulaCertificateDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_RawNebulaCertificateDetails.Marshal(b, m, deterministic)
}
func (m *RawNebulaCertificateDetails) XXX_Merge(src proto.Message) {
xxx_messageInfo_RawNebulaCertificateDetails.Merge(m, src)
}
func (m *RawNebulaCertificateDetails) XXX_Size() int {
return xxx_messageInfo_RawNebulaCertificateDetails.Size(m)
}
func (m *RawNebulaCertificateDetails) XXX_DiscardUnknown() {
xxx_messageInfo_RawNebulaCertificateDetails.DiscardUnknown(m)
}
var xxx_messageInfo_RawNebulaCertificateDetails proto.InternalMessageInfo
func (m *RawNebulaCertificateDetails) GetName() string {
if m != nil {
return m.Name
}
return ""
}
func (m *RawNebulaCertificateDetails) GetIps() []uint32 {
if m != nil {
return m.Ips
}
return nil
}
func (m *RawNebulaCertificateDetails) GetSubnets() []uint32 {
if m != nil {
return m.Subnets
}
return nil
}
func (m *RawNebulaCertificateDetails) GetGroups() []string {
if m != nil {
return m.Groups
}
return nil
}
func (m *RawNebulaCertificateDetails) GetNotBefore() int64 {
if m != nil {
return m.NotBefore
}
return 0
}
func (m *RawNebulaCertificateDetails) GetNotAfter() int64 {
if m != nil {
return m.NotAfter
}
return 0
}
func (m *RawNebulaCertificateDetails) GetPublicKey() []byte {
if m != nil {
return m.PublicKey
}
return nil
}
func (m *RawNebulaCertificateDetails) GetIsCA() bool {
if m != nil {
return m.IsCA
}
return false
}
func (m *RawNebulaCertificateDetails) GetIssuer() []byte {
if m != nil {
return m.Issuer
}
return nil
}
func init() {
proto.RegisterType((*RawNebulaCertificate)(nil), "cert.RawNebulaCertificate")
proto.RegisterType((*RawNebulaCertificateDetails)(nil), "cert.RawNebulaCertificateDetails")
}
func init() { proto.RegisterFile("cert.proto", fileDescriptor_a142e29cbef9b1cf) }
var fileDescriptor_a142e29cbef9b1cf = []byte{
// 279 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x90, 0xcf, 0x4a, 0xf4, 0x30,
0x14, 0xc5, 0xc9, 0xa4, 0x5f, 0xdb, 0xe4, 0x53, 0x90, 0x20, 0x12, 0xd4, 0x45, 0x9c, 0x55, 0x56,
0xb3, 0xd0, 0xa5, 0xab, 0x71, 0x04, 0x29, 0x42, 0x91, 0xcc, 0x13, 0xa4, 0xf5, 0x76, 0x08, 0x74,
0x9a, 0x9a, 0x3f, 0x88, 0x8f, 0xee, 0x4e, 0x9a, 0x4e, 0x77, 0xe2, 0xee, 0x9e, 0x5f, 0xce, 0x49,
0x4e, 0x2e, 0xa5, 0x2d, 0xb8, 0xb0, 0x19, 0x9d, 0x0d, 0x96, 0x65, 0xd3, 0xbc, 0xfe, 0xa0, 0x97,
0x4a, 0x7f, 0xd6, 0xd0, 0xc4, 0x5e, 0xef, 0xc0, 0x05, 0xd3, 0x99, 0x56, 0x07, 0x60, 0x8f, 0xb4,
0x78, 0x86, 0xa0, 0x4d, 0xef, 0x39, 0x12, 0x48, 0xfe, 0xbf, 0xbf, 0xdb, 0xa4, 0xec, 0x6f, 0xe6,
0x93, 0x51, 0x15, 0xef, 0xf3, 0xc0, 0x6e, 0x29, 0xd9, 0x9b, 0xc3, 0xa0, 0x43, 0x74, 0xc0, 0x57,
0x02, 0xc9, 0x33, 0x45, 0xfc, 0x02, 0xd6, 0xdf, 0x88, 0xde, 0xfc, 0x71, 0x0d, 0x63, 0x34, 0xab,
0xf5, 0x11, 0xd2, 0xbb, 0x44, 0x65, 0x83, 0x3e, 0x02, 0xbb, 0xa0, 0xb8, 0x1a, 0x3d, 0x5f, 0x09,
0x2c, 0xcf, 0x15, 0x36, 0xa3, 0x67, 0x9c, 0x16, 0xfb, 0xd8, 0x0c, 0x10, 0x3c, 0xc7, 0x89, 0x16,
0x7e, 0x96, 0xec, 0x8a, 0xe6, 0x2f, 0xce, 0xc6, 0xd1, 0xf3, 0x4c, 0x60, 0x49, 0x54, 0x7e, 0x48,
0x6a, 0x6a, 0x55, 0xdb, 0xf0, 0x04, 0x9d, 0x75, 0xc0, 0xff, 0x09, 0x24, 0xb1, 0x22, 0xc3, 0x02,
0xd8, 0x35, 0x2d, 0x6b, 0x1b, 0xb6, 0x5d, 0x00, 0xc7, 0xf3, 0x74, 0x58, 0x0e, 0x27, 0x3d, 0x25,
0xdf, 0x62, 0xd3, 0x9b, 0xf6, 0x15, 0xbe, 0x78, 0x31, 0xff, 0x67, 0x5c, 0xc0, 0xd4, 0xb7, 0xf2,
0xbb, 0x2d, 0x2f, 0x05, 0x92, 0xa5, 0xca, 0x8c, 0xdf, 0x6d, 0xa7, 0x0e, 0x95, 0xf7, 0x11, 0x1c,
0x27, 0xc9, 0x9e, 0x9b, 0xa4, 0x9a, 0x3c, 0xed, 0xfe, 0xe1, 0x27, 0x00, 0x00, 0xff, 0xff, 0x2c,
0xe3, 0x08, 0x37, 0x89, 0x01, 0x00, 0x00,
}

27
cert/cert.proto Normal file
View File

@ -0,0 +1,27 @@
syntax = "proto3";
package cert;
//import "google/protobuf/timestamp.proto";
message RawNebulaCertificate {
RawNebulaCertificateDetails Details = 1;
bytes Signature = 2;
}
message RawNebulaCertificateDetails {
string Name = 1;
// Ips and Subnets are in big endian 32 bit pairs, 1st the ip, 2nd the mask
repeated uint32 Ips = 2;
repeated uint32 Subnets = 3;
repeated string Groups = 4;
int64 NotBefore = 5;
int64 NotAfter = 6;
bytes PublicKey = 7;
bool IsCA = 8;
// sha-256 of the issuer certificate, if this field is blank the cert is self-signed
bytes Issuer = 9;
}

373
cert/cert_test.go Normal file
View File

@ -0,0 +1,373 @@
package cert
import (
"crypto/rand"
"fmt"
"io"
"net"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
)
func TestMarshalingNebulaCertificate(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := NebulaCertificate{
Details: NebulaCertificateDetails{
Name: "testing",
Ips: []*net.IPNet{
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
},
Subnets: []*net.IPNet{
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before,
NotAfter: after,
PublicKey: pubKey,
IsCA: false,
Issuer: "1234567890abcedfghij1234567890ab",
},
Signature: []byte("1234567890abcedfghij1234567890ab"),
}
b, err := nc.Marshal()
assert.Nil(t, err)
t.Log("Cert size:", len(b))
nc2, err := UnmarshalNebulaCertificate(b)
assert.Nil(t, err)
assert.Equal(t, nc.Signature, nc2.Signature)
assert.Equal(t, nc.Details.Name, nc2.Details.Name)
assert.Equal(t, nc.Details.NotBefore, nc2.Details.NotBefore)
assert.Equal(t, nc.Details.NotAfter, nc2.Details.NotAfter)
assert.Equal(t, nc.Details.PublicKey, nc2.Details.PublicKey)
assert.Equal(t, nc.Details.IsCA, nc2.Details.IsCA)
// IP byte arrays can be 4 or 16 in length so we have to go this route
assert.Equal(t, len(nc.Details.Ips), len(nc2.Details.Ips))
for i, wIp := range nc.Details.Ips {
assert.Equal(t, wIp.String(), nc2.Details.Ips[i].String())
}
assert.Equal(t, len(nc.Details.Subnets), len(nc2.Details.Subnets))
for i, wIp := range nc.Details.Subnets {
assert.Equal(t, wIp.String(), nc2.Details.Subnets[i].String())
}
assert.EqualValues(t, nc.Details.Groups, nc2.Details.Groups)
}
func TestNebulaCertificate_Sign(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := NebulaCertificate{
Details: NebulaCertificateDetails{
Name: "testing",
Ips: []*net.IPNet{
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
},
Subnets: []*net.IPNet{
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before,
NotAfter: after,
PublicKey: pubKey,
IsCA: false,
Issuer: "1234567890abcedfghij1234567890ab",
},
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
assert.Nil(t, err)
assert.False(t, nc.CheckSignature(pub))
assert.Nil(t, nc.Sign(priv))
assert.True(t, nc.CheckSignature(pub))
b, err := nc.Marshal()
assert.Nil(t, err)
t.Log("Cert size:", len(b))
}
func TestNebulaCertificate_Expired(t *testing.T) {
nc := NebulaCertificate{
Details: NebulaCertificateDetails{
NotBefore: time.Now().Add(time.Second * -60).Round(time.Second),
NotAfter: time.Now().Add(time.Second * 60).Round(time.Second),
},
}
assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
assert.False(t, nc.Expired(time.Now()))
}
func TestNebulaCertificate_MarshalJSON(t *testing.T) {
time.Local = time.UTC
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := NebulaCertificate{
Details: NebulaCertificateDetails{
Name: "testing",
Ips: []*net.IPNet{
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
},
Subnets: []*net.IPNet{
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
NotAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
PublicKey: pubKey,
IsCA: false,
Issuer: "1234567890abcedfghij1234567890ab",
},
Signature: []byte("1234567890abcedfghij1234567890ab"),
}
b, err := nc.MarshalJSON()
assert.Nil(t, err)
assert.Equal(
t,
"{\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}",
string(b),
)
}
func TestNebulaCertificate_Verify(t *testing.T) {
ca, _, caKey, err := newTestCaCert()
assert.Nil(t, err)
c, _, _, err := newTestCert(ca, caKey)
assert.Nil(t, err)
h, err := ca.Sha256Sum()
assert.Nil(t, err)
caPool := NewCAPool()
caPool.CAs[h] = ca
f, err := c.Sha256Sum()
assert.Nil(t, err)
caPool.BlacklistFingerprint(f)
v, err := c.Verify(time.Now(), caPool)
assert.False(t, v)
assert.EqualError(t, err, "certificate has been blacklisted")
caPool.ResetCertBlacklist()
v, err = c.Verify(time.Now(), caPool)
assert.True(t, v)
assert.Nil(t, err)
v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool)
assert.False(t, v)
assert.EqualError(t, err, "root certificate is expired")
}
func TestNebulaVerifyPrivateKey(t *testing.T) {
ca, _, caKey, err := newTestCaCert()
assert.Nil(t, err)
c, _, priv, err := newTestCert(ca, caKey)
err = c.VerifyPrivateKey(priv)
assert.Nil(t, err)
_, priv2 := x25519Keypair()
err = c.VerifyPrivateKey(priv2)
assert.NotNil(t, err)
}
func TestNewCAPoolFromBytes(t *testing.T) {
noNewLines := `
# Current provisional, Remove once everything moves over to the real root.
-----BEGIN NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
-----END NEBULA CERTIFICATE-----
# root-ca01
-----BEGIN NEBULA CERTIFICATE-----
CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
-----END NEBULA CERTIFICATE-----
`
withNewLines := `
# Current provisional, Remove once everything moves over to the real root.
-----BEGIN NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
-----END NEBULA CERTIFICATE-----
# root-ca01
-----BEGIN NEBULA CERTIFICATE-----
CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
-----END NEBULA CERTIFICATE-----
`
rootCA := NebulaCertificate{
Details: NebulaCertificateDetails{
Name: "nebula root ca",
},
}
rootCA01 := NebulaCertificate{
Details: NebulaCertificateDetails{
Name: "nebula root ca 01",
},
}
p, err := NewCAPoolFromBytes([]byte(noNewLines))
assert.Nil(t, err)
assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
pp, err := NewCAPoolFromBytes([]byte(withNewLines))
assert.Nil(t, err)
assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
}
// Ensure that upgrading the protobuf library does not change how certificates
// are marshalled, since this would break signature verification
func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := NebulaCertificate{
Details: NebulaCertificateDetails{
Name: "testing",
Ips: []*net.IPNet{
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
},
Subnets: []*net.IPNet{
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before,
NotAfter: after,
PublicKey: pubKey,
IsCA: false,
Issuer: "1234567890abcedfghij1234567890ab",
},
Signature: []byte("1234567890abcedfghij1234567890ab"),
}
b, err := nc.Marshal()
assert.Nil(t, err)
t.Log("Cert size:", len(b))
assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
b, err = proto.Marshal(nc.getRawDetails())
assert.Nil(t, err)
t.Log("Raw cert size:", len(b))
assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
}
func newTestCaCert() (*NebulaCertificate, []byte, []byte, error) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
nc := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: "test ca",
NotBefore: before,
NotAfter: after,
PublicKey: pub,
IsCA: true,
},
}
err = nc.Sign(priv)
if err != nil {
return nil, nil, nil, err
}
return nc, pub, priv, nil
}
func newTestCert(ca *NebulaCertificate, key []byte) (*NebulaCertificate, []byte, []byte, error) {
issuer, err := ca.Sha256Sum()
if err != nil {
return nil, nil, nil, err
}
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pub, rawPriv := x25519Keypair()
nc := &NebulaCertificate{
Details: NebulaCertificateDetails{
Name: "testing",
Ips: []*net.IPNet{
{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
},
Subnets: []*net.IPNet{
{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before,
NotAfter: after,
PublicKey: pub,
IsCA: false,
Issuer: issuer,
},
}
err = nc.Sign(key)
if err != nil {
return nil, nil, nil, err
}
return nc, pub, rawPriv, nil
}
func x25519Keypair() ([]byte, []byte) {
var pubkey, privkey [32]byte
if _, err := io.ReadFull(rand.Reader, privkey[:]); err != nil {
panic(err)
}
curve25519.ScalarBaseMult(&pubkey, &privkey)
return pubkey[:], privkey[:]
}

147
cidr_radix.go Normal file
View File

@ -0,0 +1,147 @@
package nebula
import (
"encoding/binary"
"fmt"
"net"
)
type CIDRNode struct {
left *CIDRNode
right *CIDRNode
parent *CIDRNode
value interface{}
}
type CIDRTree struct {
root *CIDRNode
}
const (
startbit = uint32(0x80000000)
)
func NewCIDRTree() *CIDRTree {
tree := new(CIDRTree)
tree.root = &CIDRNode{}
return tree
}
func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
bit := startbit
node := tree.root
next := tree.root
ip := ip2int(cidr.IP)
mask := ip2int(cidr.Mask)
// Find our last ancestor in the tree
for bit&mask != 0 {
if ip&bit != 0 {
next = node.right
} else {
next = node.left
}
if next == nil {
break
}
bit = bit >> 1
node = next
}
// We already have this range so update the value
if next != nil {
node.value = val
return
}
// Build up the rest of the tree we don't already have
for bit&mask != 0 {
next = &CIDRNode{}
next.parent = node
if ip&bit != 0 {
node.right = next
} else {
node.left = next
}
bit >>= 1
node = next
}
// Final node marks our cidr, set the value
node.value = val
}
// Finds the first match, which way be the least specific
func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
bit := startbit
node := tree.root
for node != nil {
if node.value != nil {
return node.value
}
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit >>= 1
}
return value
}
// Finds the most specific match
func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
bit := startbit
node := tree.root
lastNode := node
for node != nil {
lastNode = node
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit >>= 1
}
if bit == 0 && lastNode != nil {
value = lastNode.value
}
return value
}
// A helper type to avoid converting to IP when logging
type IntIp uint32
func (ip IntIp) String() string {
return fmt.Sprintf("%v", int2ip(uint32(ip)))
}
func (ip IntIp) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("\"%s\"", int2ip(uint32(ip)).String())), nil
}
func ip2int(ip []byte) uint32 {
if len(ip) == 16 {
return binary.BigEndian.Uint32(ip[12:16])
}
return binary.BigEndian.Uint32(ip)
}
func int2ip(nn uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, nn)
return ip
}

118
cidr_radix_test.go Normal file
View File

@ -0,0 +1,118 @@
package nebula
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCIDRTree_Contains(t *testing.T) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
tree.AddCIDR(getCIDR("4.1.1.1/32"), "4b")
tree.AddCIDR(getCIDR("4.1.2.1/32"), "4c")
tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
tests := []struct {
Result interface{}
IP string
}{
{"1", "1.0.0.0"},
{"1", "1.255.255.255"},
{"2", "2.1.0.0"},
{"2", "2.1.255.255"},
{"3", "3.1.1.0"},
{"3", "3.1.1.255"},
{"4a", "4.1.1.255"},
{"4a", "4.1.1.1"},
{"5", "240.0.0.0"},
{"5", "255.255.255.255"},
{nil, "239.0.0.0"},
{nil, "4.1.2.2"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.Contains(ip2int(net.ParseIP(tt.IP))))
}
tree = NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
}
func TestCIDRTree_Match(t *testing.T) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("4.1.1.0/32"), "1a")
tree.AddCIDR(getCIDR("4.1.1.1/32"), "1b")
tests := []struct {
Result interface{}
IP string
}{
{"1a", "4.1.1.0"},
{"1b", "4.1.1.1"},
}
for _, tt := range tests {
assert.Equal(t, tt.Result, tree.Match(ip2int(net.ParseIP(tt.IP))))
}
tree = NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
}
func BenchmarkCIDRTree_Contains(b *testing.B) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
ip := ip2int(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
ip = ip2int(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Contains(ip)
}
})
}
func BenchmarkCIDRTree_Match(b *testing.B) {
tree := NewCIDRTree()
tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
ip := ip2int(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
ip = ip2int(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
}
func getCIDR(s string) *net.IPNet {
_, c, _ := net.ParseCIDR(s)
return c
}

124
cmd/nebula-cert/ca.go Normal file
View File

@ -0,0 +1,124 @@
package main
import (
"crypto/rand"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"strings"
"time"
"golang.org/x/crypto/ed25519"
"github.com/slackhq/nebula/cert"
)
type caFlags struct {
set *flag.FlagSet
name *string
duration *time.Duration
outKeyPath *string
outCertPath *string
groups *string
}
func newCaFlags() *caFlags {
cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
cf.set.Usage = func() {}
cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
return &cf
}
func ca(args []string, out io.Writer, errOut io.Writer) error {
cf := newCaFlags()
err := cf.set.Parse(args)
if err != nil {
return err
}
if err := mustFlagString("name", cf.name); err != nil {
return err
}
if err := mustFlagString("out-key", cf.outKeyPath); err != nil {
return err
}
if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
return err
}
if *cf.duration <= 0 {
return &helpError{"-duration must be greater than 0"}
}
groups := []string{}
if *cf.groups != "" {
for _, rg := range strings.Split(*cf.groups, ",") {
g := strings.TrimSpace(rg)
if g != "" {
groups = append(groups, g)
}
}
}
pub, rawPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return fmt.Errorf("error while generating ed25519 keys: %s", err)
}
nc := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: *cf.name,
Groups: groups,
NotBefore: time.Now(),
NotAfter: time.Now().Add(*cf.duration),
PublicKey: pub,
IsCA: true,
},
}
if _, err := os.Stat(*cf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
}
if _, err := os.Stat(*cf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
}
err = nc.Sign(rawPriv)
if err != nil {
return fmt.Errorf("error while signing: %s", err)
}
err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalEd25519PrivateKey(rawPriv), 0600)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
b, err := nc.MarshalToPEM()
if err != nil {
return fmt.Errorf("error while marshalling certificate: %s", err)
}
err = ioutil.WriteFile(*cf.outCertPath, b, 0600)
if err != nil {
return fmt.Errorf("error while writing out-crt: %s", err)
}
return nil
}
func caSummary() string {
return "ca <flags>: create a self signed certificate authority"
}
func caHelp(out io.Writer) {
cf := newCaFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + caSummary() + "\n"))
cf.set.SetOutput(out)
cf.set.PrintDefaults()
}

132
cmd/nebula-cert/ca_test.go Normal file
View File

@ -0,0 +1,132 @@
package main
import (
"bytes"
"io/ioutil"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/slackhq/nebula/cert"
)
//TODO: test file permissions
func Test_caSummary(t *testing.T) {
assert.Equal(t, "ca <flags>: create a self signed certificate authority", caSummary())
}
func Test_caHelp(t *testing.T) {
ob := &bytes.Buffer{}
caHelp(ob)
assert.Equal(
t,
"Usage of "+os.Args[0]+" ca <flags>: create a self signed certificate authority\n"+
" -duration duration\n"+
" \tOptional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\" (default 8760h0m0s)\n"+
" -groups string\n"+
" \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
" -name string\n"+
" \tRequired: name of the certificate authority\n"+
" -out-crt string\n"+
" \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
" -out-key string\n"+
" \tOptional: path to write the private key to (default \"ca.key\")\n",
ob.String(),
)
}
func Test_ca(t *testing.T) {
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
// required args
assertHelpError(t, ca([]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb), "-name is required")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// failed key write
ob.Reset()
eb.Reset()
args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
assert.EqualError(t, ca(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp key file
keyF, err := ioutil.TempFile("", "test.key")
assert.Nil(t, err)
os.Remove(keyF.Name())
// failed cert write
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp cert file
crtF, err := ioutil.TempFile("", "test.crt")
assert.Nil(t, err)
os.Remove(crtF.Name())
os.Remove(keyF.Name())
// test proper cert with removed empty groups and subnets
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb))
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// read cert and key files
rb, _ := ioutil.ReadFile(keyF.Name())
lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
assert.Len(t, lKey, 64)
rb, _ = ioutil.ReadFile(crtF.Name())
lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
assert.Equal(t, "test", lCrt.Details.Name)
assert.Len(t, lCrt.Details.Ips, 0)
assert.True(t, lCrt.Details.IsCA)
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
assert.Len(t, lCrt.Details.Subnets, 0)
assert.Len(t, lCrt.Details.PublicKey, 32)
assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
assert.Equal(t, "", lCrt.Details.Issuer)
assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey))
// create valid cert/key for overwrite tests
os.Remove(keyF.Name())
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb))
// test that we won't overwrite existing certificate file
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA key: "+keyF.Name())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// test that we won't overwrite existing key file
os.Remove(keyF.Name())
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA cert: "+crtF.Name())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
os.Remove(keyF.Name())
}

65
cmd/nebula-cert/keygen.go Normal file
View File

@ -0,0 +1,65 @@
package main
import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"github.com/slackhq/nebula/cert"
)
type keygenFlags struct {
set *flag.FlagSet
outKeyPath *string
outPubPath *string
}
func newKeygenFlags() *keygenFlags {
cf := keygenFlags{set: flag.NewFlagSet("keygen", flag.ContinueOnError)}
cf.set.Usage = func() {}
cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to")
cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to")
return &cf
}
func keygen(args []string, out io.Writer, errOut io.Writer) error {
cf := newKeygenFlags()
err := cf.set.Parse(args)
if err != nil {
return err
}
if err := mustFlagString("out-key", cf.outKeyPath); err != nil {
return err
}
if err := mustFlagString("out-pub", cf.outPubPath); err != nil {
return err
}
pub, rawPriv := x25519Keypair()
err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalX25519PrivateKey(rawPriv), 0600)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
err = ioutil.WriteFile(*cf.outPubPath, cert.MarshalX25519PublicKey(pub), 0600)
if err != nil {
return fmt.Errorf("error while writing out-pub: %s", err)
}
return nil
}
func keygenSummary() string {
return "keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`"
}
func keygenHelp(out io.Writer) {
cf := newKeygenFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
cf.set.SetOutput(out)
cf.set.PrintDefaults()
}

View File

@ -0,0 +1,92 @@
package main
import (
"bytes"
"io/ioutil"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/slackhq/nebula/cert"
)
//TODO: test file permissions
func Test_keygenSummary(t *testing.T) {
assert.Equal(t, "keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary())
}
func Test_keygenHelp(t *testing.T) {
ob := &bytes.Buffer{}
keygenHelp(ob)
assert.Equal(
t,
"Usage of "+os.Args[0]+" keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+
" -out-key string\n"+
" \tRequired: path to write the private key to\n"+
" -out-pub string\n"+
" \tRequired: path to write the public key to\n",
ob.String(),
)
}
func Test_keygen(t *testing.T) {
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
// required args
assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// failed key write
ob.Reset()
eb.Reset()
args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp key file
keyF, err := ioutil.TempFile("", "test.key")
assert.Nil(t, err)
defer os.Remove(keyF.Name())
// failed pub write
ob.Reset()
eb.Reset()
args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp pub file
pubF, err := ioutil.TempFile("", "test.pub")
assert.Nil(t, err)
defer os.Remove(pubF.Name())
// test proper keygen
ob.Reset()
eb.Reset()
args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, keygen(args, ob, eb))
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// read cert and key files
rb, _ := ioutil.ReadFile(keyF.Name())
lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
assert.Len(t, lKey, 32)
rb, _ = ioutil.ReadFile(pubF.Name())
lPub, b, err := cert.UnmarshalX25519PublicKey(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
assert.Len(t, lPub, 32)
}

137
cmd/nebula-cert/main.go Normal file
View File

@ -0,0 +1,137 @@
package main
import (
"flag"
"fmt"
"io"
"os"
)
var Build string
type helpError struct {
s string
}
func (he *helpError) Error() string {
return he.s
}
func newHelpErrorf(s string, v ...interface{}) error {
return &helpError{s: fmt.Sprintf(s, v...)}
}
func main() {
flag.Usage = func() {
help("", os.Stderr)
os.Exit(1)
}
printVersion := flag.Bool("version", false, "Print version")
flagHelp := flag.Bool("help", false, "Print command line usage")
flagH := flag.Bool("h", false, "Print command line usage")
printUsage := false
flag.Parse()
if *flagH || *flagHelp {
printUsage = true
}
args := flag.Args()
if *printVersion {
fmt.Printf("Version: %v", Build)
os.Exit(0)
}
if len(args) < 1 {
if printUsage {
help("", os.Stderr)
os.Exit(0)
}
help("No mode was provided", os.Stderr)
os.Exit(1)
} else if printUsage {
handleError(args[0], &helpError{}, os.Stderr)
os.Exit(0)
}
var err error
switch args[0] {
case "ca":
err = ca(args[1:], os.Stdout, os.Stderr)
case "keygen":
err = keygen(args[1:], os.Stdout, os.Stderr)
case "sign":
err = signCert(args[1:], os.Stdout, os.Stderr)
case "print":
err = printCert(args[1:], os.Stdout, os.Stderr)
case "verify":
err = verify(args[1:], os.Stdout, os.Stderr)
default:
err = fmt.Errorf("unknown mode: %s", args[0])
}
if err != nil {
os.Exit(handleError(args[0], err, os.Stderr))
}
}
func handleError(mode string, e error, out io.Writer) int {
code := 1
// Handle -help, -h flags properly
if e == flag.ErrHelp {
code = 0
e = &helpError{}
} else if e != nil && e.Error() != "" {
fmt.Fprintln(out, "Error:", e)
}
switch e.(type) {
case *helpError:
switch mode {
case "ca":
caHelp(out)
case "keygen":
keygenHelp(out)
case "sign":
signHelp(out)
case "print":
printHelp(out)
case "verify":
verifyHelp(out)
}
}
return code
}
func help(err string, out io.Writer) {
if err != "" {
fmt.Fprintln(out, "Error:", err)
fmt.Fprintln(out, "")
}
fmt.Fprintf(out, "Usage of %s <global flags> <mode>:\n", os.Args[0])
fmt.Fprintln(out, " Global flags:")
fmt.Fprintln(out, " -version: Prints the version")
fmt.Fprintln(out, " -h, -help: Prints this help message")
fmt.Fprintln(out, "")
fmt.Fprintln(out, " Modes:")
fmt.Fprintln(out, " "+caSummary())
fmt.Fprintln(out, " "+keygenSummary())
fmt.Fprintln(out, " "+signSummary())
fmt.Fprintln(out, " "+printSummary())
fmt.Fprintln(out, " "+verifySummary())
}
func mustFlagString(name string, val *string) error {
if *val == "" {
return newHelpErrorf("-%s is required", name)
}
return nil
}

View File

@ -0,0 +1,81 @@
package main
import (
"bytes"
"errors"
"github.com/stretchr/testify/assert"
"io"
"os"
"testing"
)
//TODO: all flag parsing continueOnError will print to stderr on its own currently
func Test_help(t *testing.T) {
expected := "Usage of " + os.Args[0] + " <global flags> <mode>:\n" +
" Global flags:\n" +
" -version: Prints the version\n" +
" -h, -help: Prints this help message\n\n" +
" Modes:\n" +
" " + caSummary() + "\n" +
" " + keygenSummary() + "\n" +
" " + signSummary() + "\n" +
" " + printSummary() + "\n" +
" " + verifySummary() + "\n"
ob := &bytes.Buffer{}
// No error test
help("", ob)
assert.Equal(
t,
expected,
ob.String(),
)
// Error test
ob.Reset()
help("test error", ob)
assert.Equal(
t,
"Error: test error\n\n"+expected,
ob.String(),
)
}
func Test_handleError(t *testing.T) {
ob := &bytes.Buffer{}
// normal error
handleError("", errors.New("test error"), ob)
assert.Equal(t, "Error: test error\n", ob.String())
// unknown mode help error
ob.Reset()
handleError("", newHelpErrorf("test %s", "error"), ob)
assert.Equal(t, "Error: test error\n", ob.String())
// test all modes with help error
modes := map[string]func(io.Writer){"ca": caHelp, "print": printHelp, "sign": signHelp, "verify": verifyHelp}
eb := &bytes.Buffer{}
for mode, fn := range modes {
ob.Reset()
eb.Reset()
fn(eb)
handleError(mode, newHelpErrorf("test %s", "error"), ob)
assert.Equal(t, "Error: test error\n"+eb.String(), ob.String())
}
}
func assertHelpError(t *testing.T, err error, msg string) {
switch err.(type) {
case *helpError:
// good
default:
t.Fatal("err was not a helpError")
}
assert.EqualError(t, err, msg)
}

80
cmd/nebula-cert/print.go Normal file
View File

@ -0,0 +1,80 @@
package main
import (
"encoding/json"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"github.com/slackhq/nebula/cert"
"strings"
)
type printFlags struct {
set *flag.FlagSet
json *bool
path *string
}
func newPrintFlags() *printFlags {
pf := printFlags{set: flag.NewFlagSet("print", flag.ContinueOnError)}
pf.set.Usage = func() {}
pf.json = pf.set.Bool("json", false, "Optional: outputs certificates in json format")
pf.path = pf.set.String("path", "", "Required: path to the certificate")
return &pf
}
func printCert(args []string, out io.Writer, errOut io.Writer) error {
pf := newPrintFlags()
err := pf.set.Parse(args)
if err != nil {
return err
}
if err := mustFlagString("path", pf.path); err != nil {
return err
}
rawCert, err := ioutil.ReadFile(*pf.path)
if err != nil {
return fmt.Errorf("unable to read cert; %s", err)
}
var c *cert.NebulaCertificate
for {
c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert)
if err != nil {
return fmt.Errorf("error while unmarshaling cert: %s", err)
}
if *pf.json {
b, _ := json.Marshal(c)
out.Write(b)
out.Write([]byte("\n"))
} else {
out.Write([]byte(c.String()))
out.Write([]byte("\n"))
}
if rawCert == nil || len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
break
}
}
return nil
}
func printSummary() string {
return "print <flags>: prints details about a certificate"
}
func printHelp(out io.Writer) {
pf := newPrintFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + printSummary() + "\n"))
pf.set.SetOutput(out)
pf.set.PrintDefaults()
}

View File

@ -0,0 +1,119 @@
package main
import (
"bytes"
"github.com/stretchr/testify/assert"
"io/ioutil"
"os"
"github.com/slackhq/nebula/cert"
"testing"
"time"
)
func Test_printSummary(t *testing.T) {
assert.Equal(t, "print <flags>: prints details about a certificate", printSummary())
}
func Test_printHelp(t *testing.T) {
ob := &bytes.Buffer{}
printHelp(ob)
assert.Equal(
t,
"Usage of "+os.Args[0]+" print <flags>: prints details about a certificate\n"+
" -json\n"+
" \tOptional: outputs certificates in json format\n"+
" -path string\n"+
" \tRequired: path to the certificate\n",
ob.String(),
)
}
func Test_printCert(t *testing.T) {
// Orient our local time and avoid headaches
time.Local = time.UTC
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
// no path
err := printCert([]string{}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assertHelpError(t, err, "-path is required")
// no cert at path
ob.Reset()
eb.Reset()
err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
// invalid cert at path
ob.Reset()
eb.Reset()
tf, err := ioutil.TempFile("", "print-cert")
assert.Nil(t, err)
defer os.Remove(tf.Name())
tf.WriteString("-----BEGIN NOPE-----")
err = printCert([]string{"-path", tf.Name()}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
// test multiple certs
ob.Reset()
eb.Reset()
tf.Truncate(0)
tf.Seek(0, 0)
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test",
Groups: []string{"hi"},
PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
},
Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
}
p, _ := c.MarshalToPEM()
tf.Write(p)
tf.Write(p)
tf.Write(p)
err = printCert([]string{"-path", tf.Name()}, ob, eb)
assert.Nil(t, err)
assert.Equal(
t,
"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n",
ob.String(),
)
assert.Equal(t, "", eb.String())
// test json
ob.Reset()
eb.Reset()
tf.Truncate(0)
tf.Seek(0, 0)
c = cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test",
Groups: []string{"hi"},
PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
},
Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
}
p, _ = c.MarshalToPEM()
tf.Write(p)
tf.Write(p)
tf.Write(p)
err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb)
assert.Nil(t, err)
assert.Equal(
t,
"{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n",
ob.String(),
)
assert.Equal(t, "", eb.String())
}

227
cmd/nebula-cert/sign.go Normal file
View File

@ -0,0 +1,227 @@
package main
import (
"crypto/rand"
"flag"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"strings"
"time"
"golang.org/x/crypto/curve25519"
"github.com/slackhq/nebula/cert"
)
type signFlags struct {
set *flag.FlagSet
caKeyPath *string
caCertPath *string
name *string
ip *string
duration *time.Duration
inPubPath *string
outKeyPath *string
outCertPath *string
groups *string
subnets *string
}
func newSignFlags() *signFlags {
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
sf.set.Usage = func() {}
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
sf.ip = sf.set.String("ip", "", "Required: ip and network in CIDR notation to assign the cert")
sf.duration = sf.set.Duration("duration", 0, "Required: how long the cert should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
sf.subnets = sf.set.String("subnets", "", "Optional: comma seperated list of subnet this cert can serve for")
return &sf
}
func signCert(args []string, out io.Writer, errOut io.Writer) error {
sf := newSignFlags()
err := sf.set.Parse(args)
if err != nil {
return err
}
if err := mustFlagString("ca-key", sf.caKeyPath); err != nil {
return err
}
if err := mustFlagString("ca-crt", sf.caCertPath); err != nil {
return err
}
if err := mustFlagString("name", sf.name); err != nil {
return err
}
if err := mustFlagString("ip", sf.ip); err != nil {
return err
}
if *sf.inPubPath != "" && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set both -in-pub and -out-key")
}
rawCAKey, err := ioutil.ReadFile(*sf.caKeyPath)
if err != nil {
return fmt.Errorf("error while reading ca-key: %s", err)
}
caKey, _, err := cert.UnmarshalEd25519PrivateKey(rawCAKey)
if err != nil {
return fmt.Errorf("error while parsing ca-key: %s", err)
}
rawCACert, err := ioutil.ReadFile(*sf.caCertPath)
if err != nil {
return fmt.Errorf("error while reading ca-crt: %s", err)
}
caCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCACert)
if err != nil {
return fmt.Errorf("error while parsing ca-crt: %s", err)
}
issuer, err := caCert.Sha256Sum()
if err != nil {
return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err)
}
if caCert.Expired(time.Now()) {
return fmt.Errorf("ca certificate is expired")
}
// if no duration is given, expire one second before the root expires
if *sf.duration <= 0 {
*sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1
}
if caCert.Details.NotAfter.Before(time.Now().Add(*sf.duration)) {
return fmt.Errorf("refusing to generate certificate with duration beyond root expiration: %s", caCert.Details.NotAfter)
}
ip, ipNet, err := net.ParseCIDR(*sf.ip)
if err != nil {
return newHelpErrorf("invalid ip definition: %s", err)
}
ipNet.IP = ip
groups := []string{}
if *sf.groups != "" {
for _, rg := range strings.Split(*sf.groups, ",") {
g := strings.TrimSpace(rg)
if g != "" {
groups = append(groups, g)
}
}
}
subnets := []*net.IPNet{}
if *sf.subnets != "" {
for _, rs := range strings.Split(*sf.subnets, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
_, s, err := net.ParseCIDR(rs)
if err != nil {
return newHelpErrorf("invalid subnet definition: %s", err)
}
subnets = append(subnets, s)
}
}
}
var pub, rawPriv []byte
if *sf.inPubPath != "" {
rawPub, err := ioutil.ReadFile(*sf.inPubPath)
if err != nil {
return fmt.Errorf("error while reading in-pub: %s", err)
}
pub, _, err = cert.UnmarshalX25519PublicKey(rawPub)
if err != nil {
return fmt.Errorf("error while parsing in-pub: %s", err)
}
} else {
pub, rawPriv = x25519Keypair()
}
nc := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: *sf.name,
Ips: []*net.IPNet{ipNet},
Groups: groups,
Subnets: subnets,
NotBefore: time.Now(),
NotAfter: time.Now().Add(*sf.duration),
PublicKey: pub,
IsCA: false,
Issuer: issuer,
},
}
if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key"
}
if *sf.outCertPath == "" {
*sf.outCertPath = *sf.name + ".crt"
}
if _, err := os.Stat(*sf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
}
if _, err := os.Stat(*sf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
}
err = nc.Sign(caKey)
if err != nil {
return fmt.Errorf("error while signing: %s", err)
}
if *sf.inPubPath == "" {
err = ioutil.WriteFile(*sf.outKeyPath, cert.MarshalX25519PrivateKey(rawPriv), 0600)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
}
b, err := nc.MarshalToPEM()
if err != nil {
return fmt.Errorf("error while marshalling certificate: %s", err)
}
err = ioutil.WriteFile(*sf.outCertPath, b, 0600)
if err != nil {
return fmt.Errorf("error while writing out-crt: %s", err)
}
return nil
}
func x25519Keypair() ([]byte, []byte) {
var pubkey, privkey [32]byte
if _, err := io.ReadFull(rand.Reader, privkey[:]); err != nil {
panic(err)
}
curve25519.ScalarBaseMult(&pubkey, &privkey)
return pubkey[:], privkey[:]
}
func signSummary() string {
return "sign <flags>: create and sign a certificate"
}
func signHelp(out io.Writer) {
sf := newSignFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n"))
sf.set.SetOutput(out)
sf.set.PrintDefaults()
}

View File

@ -0,0 +1,281 @@
package main
import (
"bytes"
"crypto/rand"
"io/ioutil"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
"github.com/slackhq/nebula/cert"
)
//TODO: test file permissions
func Test_signSummary(t *testing.T) {
assert.Equal(t, "sign <flags>: create and sign a certificate", signSummary())
}
func Test_signHelp(t *testing.T) {
ob := &bytes.Buffer{}
signHelp(ob)
assert.Equal(
t,
"Usage of "+os.Args[0]+" sign <flags>: create and sign a certificate\n"+
" -ca-crt string\n"+
" \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+
" -ca-key string\n"+
" \tOptional: path to the signing CA key (default \"ca.key\")\n"+
" -duration duration\n"+
" \tRequired: how long the cert should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"\n"+
" -groups string\n"+
" \tOptional: comma separated list of groups\n"+
" -in-pub string\n"+
" \tOptional (if out-key not set): path to read a previously generated public key\n"+
" -ip string\n"+
" \tRequired: ip and network in CIDR notation to assign the cert\n"+
" -name string\n"+
" \tRequired: name of the cert, usually a hostname\n"+
" -out-crt string\n"+
" \tOptional: path to write the certificate to\n"+
" -out-key string\n"+
" \tOptional (if in-pub not set): path to write the private key to\n"+
" -subnets string\n"+
" \tOptional: comma seperated list of subnet this cert can serve for\n",
ob.String(),
)
}
func Test_signCert(t *testing.T) {
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
// required args
assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-name is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-ip is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// cannot set -in-pub and -out-key
assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb), "cannot set both -in-pub and -out-key")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// failed to read key
ob.Reset()
eb.Reset()
args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-key: open ./nope: "+NoSuchFileError)
// failed to unmarshal key
ob.Reset()
eb.Reset()
caKeyF, err := ioutil.TempFile("", "sign-cert.key")
assert.Nil(t, err)
defer os.Remove(caKeyF.Name())
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-key: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// Write a proper ca key for later
ob.Reset()
eb.Reset()
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
caKeyF.Write(cert.MarshalEd25519PrivateKey(caPriv))
// failed to read cert
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// failed to unmarshal cert
ob.Reset()
eb.Reset()
caCrtF, err := ioutil.TempFile("", "sign-cert.crt")
assert.Nil(t, err)
defer os.Remove(caCrtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// write a proper ca cert for later
ca := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "ca",
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute * 200),
PublicKey: caPub,
IsCA: true,
},
}
b, _ := ca.MarshalToPEM()
caCrtF.Write(b)
// failed to read pub
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb), "error while reading in-pub: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// failed to unmarshal pub
ob.Reset()
eb.Reset()
inPubF, err := ioutil.TempFile("", "in.pub")
assert.Nil(t, err)
defer os.Remove(inPubF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb), "error while parsing in-pub: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// write a proper pub for later
ob.Reset()
eb.Reset()
inPub, _ := x25519Keypair()
inPubF.Write(cert.MarshalX25519PublicKey(inPub))
// bad ip cidr
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb), "invalid ip definition: invalid CIDR address: a1.1.1.1/24")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// bad subnet cidr
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assertHelpError(t, signCert(args, ob, eb), "invalid subnet definition: invalid CIDR address: a")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// failed key write
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// create temp key file
keyF, err := ioutil.TempFile("", "test.key")
assert.Nil(t, err)
os.Remove(keyF.Name())
// failed cert write
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
os.Remove(keyF.Name())
// create temp cert file
crtF, err := ioutil.TempFile("", "test.crt")
assert.Nil(t, err)
os.Remove(crtF.Name())
// test proper cert with removed empty groups and subnets
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// read cert and key files
rb, _ := ioutil.ReadFile(keyF.Name())
lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
assert.Len(t, lKey, 32)
rb, _ = ioutil.ReadFile(crtF.Name())
lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
assert.Equal(t, "test", lCrt.Details.Name)
assert.Equal(t, "1.1.1.1/24", lCrt.Details.Ips[0].String())
assert.Len(t, lCrt.Details.Ips, 1)
assert.False(t, lCrt.Details.IsCA)
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
assert.Len(t, lCrt.Details.Subnets, 3)
assert.Len(t, lCrt.Details.PublicKey, 32)
assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
sns := []string{}
for _, sn := range lCrt.Details.Subnets {
sns = append(sns, sn.String())
}
assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns)
issuer, _ := ca.Sha256Sum()
assert.Equal(t, issuer, lCrt.Details.Issuer)
assert.True(t, lCrt.CheckSignature(caPub))
// test proper cert with in-pub
os.Remove(keyF.Name())
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
assert.Nil(t, signCert(args, ob, eb))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// read cert file and check pub key matches in-pub
rb, _ = ioutil.ReadFile(crtF.Name())
lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
assert.Equal(t, lCrt.Details.PublicKey, inPub)
// test refuse to sign cert with duration beyond root
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb), "refusing to generate certificate with duration beyond root expiration: "+ca.Details.NotAfter.Format("2006-01-02 15:04:05 +0000 UTC"))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// create valid cert/key for overwrite tests
os.Remove(keyF.Name())
os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb))
// test that we won't overwrite existing key file
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing key: "+keyF.Name())
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// test that we won't overwrite existing certificate file
os.Remove(keyF.Name())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing cert: "+crtF.Name())
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
}

View File

@ -0,0 +1,4 @@
package main
const NoSuchFileError = "no such file or directory"
const NoSuchDirError = "no such file or directory"

View File

@ -0,0 +1,4 @@
package main
const NoSuchFileError = "The system cannot find the file specified."
const NoSuchDirError = "The system cannot find the path specified."

86
cmd/nebula-cert/verify.go Normal file
View File

@ -0,0 +1,86 @@
package main
import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"github.com/slackhq/nebula/cert"
"strings"
"time"
)
type verifyFlags struct {
set *flag.FlagSet
caPath *string
certPath *string
}
func newVerifyFlags() *verifyFlags {
vf := verifyFlags{set: flag.NewFlagSet("verify", flag.ContinueOnError)}
vf.set.Usage = func() {}
vf.caPath = vf.set.String("ca", "", "Required: path to a file containing one or more ca certificates")
vf.certPath = vf.set.String("crt", "", "Required: path to a file containing a single certificate")
return &vf
}
func verify(args []string, out io.Writer, errOut io.Writer) error {
vf := newVerifyFlags()
err := vf.set.Parse(args)
if err != nil {
return err
}
if err := mustFlagString("ca", vf.caPath); err != nil {
return err
}
if err := mustFlagString("crt", vf.certPath); err != nil {
return err
}
rawCACert, err := ioutil.ReadFile(*vf.caPath)
if err != nil {
return fmt.Errorf("error while reading ca: %s", err)
}
caPool := cert.NewCAPool()
for {
rawCACert, err = caPool.AddCACertificate(rawCACert)
if err != nil {
return fmt.Errorf("error while adding ca cert to pool: %s", err)
}
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
break
}
}
rawCert, err := ioutil.ReadFile(*vf.certPath)
if err != nil {
return fmt.Errorf("unable to read crt; %s", err)
}
c, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
if err != nil {
return fmt.Errorf("error while parsing crt: %s", err)
}
good, err := c.Verify(time.Now(), caPool)
if !good {
return err
}
return nil
}
func verifySummary() string {
return "verify <flags>: verifies a certificate isn't expired and was signed by a trusted authority."
}
func verifyHelp(out io.Writer) {
vf := newVerifyFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
vf.set.SetOutput(out)
vf.set.PrintDefaults()
}

View File

@ -0,0 +1,141 @@
package main
import (
"bytes"
"crypto/rand"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
"io/ioutil"
"os"
"github.com/slackhq/nebula/cert"
"testing"
"time"
)
func Test_verifySummary(t *testing.T) {
assert.Equal(t, "verify <flags>: verifies a certificate isn't expired and was signed by a trusted authority.", verifySummary())
}
func Test_verifyHelp(t *testing.T) {
ob := &bytes.Buffer{}
verifyHelp(ob)
assert.Equal(
t,
"Usage of "+os.Args[0]+" verify <flags>: verifies a certificate isn't expired and was signed by a trusted authority.\n"+
" -ca string\n"+
" \tRequired: path to a file containing one or more ca certificates\n"+
" -crt string\n"+
" \tRequired: path to a file containing a single certificate\n",
ob.String(),
)
}
func Test_verify(t *testing.T) {
time.Local = time.UTC
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
// required args
assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// no ca at path
ob.Reset()
eb.Reset()
err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
// invalid ca at path
ob.Reset()
eb.Reset()
caFile, err := ioutil.TempFile("", "verify-ca")
assert.Nil(t, err)
defer os.Remove(caFile.Name())
caFile.WriteString("-----BEGIN NOPE-----")
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
// make a ca for later
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
ca := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test-ca",
NotBefore: time.Now().Add(time.Hour * -1),
NotAfter: time.Now().Add(time.Hour),
PublicKey: caPub,
IsCA: true,
},
}
ca.Sign(caPriv)
b, _ := ca.MarshalToPEM()
caFile.Truncate(0)
caFile.Seek(0, 0)
caFile.Write(b)
// no crt at path
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError)
// invalid crt at path
ob.Reset()
eb.Reset()
certFile, err := ioutil.TempFile("", "verify-cert")
assert.Nil(t, err)
defer os.Remove(certFile.Name())
certFile.WriteString("-----BEGIN NOPE-----")
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
// unverifiable cert at path
_, badPriv, _ := ed25519.GenerateKey(rand.Reader)
certPub, _ := x25519Keypair()
signer, _ := ca.Sha256Sum()
crt := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test-cert",
NotBefore: time.Now().Add(time.Hour * -1),
NotAfter: time.Now().Add(time.Hour),
PublicKey: certPub,
IsCA: false,
Issuer: signer,
},
}
crt.Sign(badPriv)
b, _ = crt.MarshalToPEM()
certFile.Truncate(0)
certFile.Seek(0, 0)
certFile.Write(b)
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "certificate signature did not match")
// verified cert at path
crt.Sign(caPriv)
b, _ = crt.MarshalToPEM()
certFile.Truncate(0)
certFile.Seek(0, 0)
certFile.Write(b)
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.Nil(t, err)
}

43
cmd/nebula/main.go Normal file
View File

@ -0,0 +1,43 @@
package main
import (
"flag"
"fmt"
"os"
"github.com/slackhq/nebula"
)
// A version string that can be set with
//
// -ldflags "-X main.Build=SOMEVERSION"
//
// at compile-time.
var Build string
func main() {
configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")
printVersion := flag.Bool("version", false, "Print version")
printUsage := flag.Bool("help", false, "Print command line usage")
flag.Parse()
if *printVersion {
fmt.Printf("Build: %s\n", Build)
os.Exit(0)
}
if *printUsage {
flag.Usage()
os.Exit(0)
}
if *configPath == "" {
fmt.Println("-config flag must be set")
flag.Usage()
os.Exit(1)
}
nebula.Main(*configPath, *configTest, Build)
}

338
config.go Normal file
View File

@ -0,0 +1,338 @@
package nebula
import (
"fmt"
"github.com/imdario/mergo"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
"io/ioutil"
"os"
"os/signal"
"path/filepath"
"sort"
"strconv"
"strings"
"syscall"
"time"
)
type Config struct {
path string
files []string
Settings map[interface{}]interface{}
oldSettings map[interface{}]interface{}
callbacks []func(*Config)
}
func NewConfig() *Config {
return &Config{
Settings: make(map[interface{}]interface{}),
}
}
// Load will find all yaml files within path and load them in lexical order
func (c *Config) Load(path string) error {
c.path = path
c.files = make([]string, 0)
err := c.resolve(path)
if err != nil {
return err
}
sort.Strings(c.files)
err = c.parse()
if err != nil {
return err
}
return nil
}
// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
// here should decide if they need to make a change to the current process before making the change. HasChanged can be
// used to help decide if a change is necessary.
// These functions should return quickly or spawn their own go routine if they will take a while
func (c *Config) RegisterReloadCallback(f func(*Config)) {
c.callbacks = append(c.callbacks, f)
}
// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
// k in both the old and new settings will be serialized, the result of the string comparison is returned.
// If k is an empty string the entire config is tested.
// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
// there is change when there actually wasn't any.
func (c *Config) HasChanged(k string) bool {
if c.oldSettings == nil {
return false
}
var (
nv interface{}
ov interface{}
)
if k == "" {
nv = c.Settings
ov = c.oldSettings
k = "all settings"
} else {
nv = c.get(k, c.Settings)
ov = c.get(k, c.oldSettings)
}
newVals, err := yaml.Marshal(nv)
if err != nil {
l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
}
oldVals, err := yaml.Marshal(ov)
if err != nil {
l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
}
return string(newVals) != string(oldVals)
}
// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
// original path provided to Load. The old settings are shallow copied for change detection after the reload.
func (c *Config) CatchHUP() {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGHUP)
go func() {
for range ch {
l.Info("Caught HUP, reloading config")
c.ReloadConfig()
}
}()
}
func (c *Config) ReloadConfig() {
c.oldSettings = make(map[interface{}]interface{})
for k, v := range c.Settings {
c.oldSettings[k] = v
}
err := c.Load(c.path)
if err != nil {
l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
return
}
for _, v := range c.callbacks {
v(c)
}
}
// GetString will get the string for k or return the default d if not found or invalid
func (c *Config) GetString(k, d string) string {
r := c.Get(k)
if r == nil {
return d
}
return fmt.Sprintf("%v", r)
}
// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
func (c *Config) GetStringSlice(k string, d []string) []string {
r := c.Get(k)
if r == nil {
return d
}
rv, ok := r.([]interface{})
if !ok {
return d
}
v := make([]string, len(rv))
for i := 0; i < len(v); i++ {
v[i] = fmt.Sprintf("%v", rv[i])
}
return v
}
// GetMap will get the map for k or return the default d if not found or invalid
func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
r := c.Get(k)
if r == nil {
return d
}
v, ok := r.(map[interface{}]interface{})
if !ok {
return d
}
return v
}
// GetInt will get the int for k or return the default d if not found or invalid
func (c *Config) GetInt(k string, d int) int {
r := c.GetString(k, strconv.Itoa(d))
v, err := strconv.Atoi(r)
if err != nil {
return d
}
return v
}
// GetBool will get the bool for k or return the default d if not found or invalid
func (c *Config) GetBool(k string, d bool) bool {
r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
v, err := strconv.ParseBool(r)
if err != nil {
switch r {
case "y", "yes":
return true
case "n", "no":
return false
}
return d
}
return v
}
// GetDuration will get the duration for k or return the default d if not found or invalid
func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
r := c.GetString(k, "")
v, err := time.ParseDuration(r)
if err != nil {
return d
}
return v
}
func (c *Config) Get(k string) interface{} {
return c.get(k, c.Settings)
}
func (c *Config) get(k string, v interface{}) interface{} {
parts := strings.Split(k, ".")
for _, p := range parts {
m, ok := v.(map[interface{}]interface{})
if !ok {
return nil
}
v, ok = m[p]
if !ok {
return nil
}
}
return v
}
func (c *Config) resolve(path string) error {
i, err := os.Stat(path)
if err != nil {
return nil
}
if !i.IsDir() {
c.addFile(path)
return nil
}
paths, err := readDirNames(path)
if err != nil {
return fmt.Errorf("problem while reading directory %s: %s", path, err)
}
for _, p := range paths {
err := c.resolve(filepath.Join(path, p))
if err != nil {
return err
}
}
return nil
}
func (c *Config) addFile(path string) error {
ext := filepath.Ext(path)
if ext != ".yaml" && ext != ".yml" {
return nil
}
ap, err := filepath.Abs(path)
if err != nil {
return err
}
c.files = append(c.files, ap)
return nil
}
func (c *Config) parse() error {
var m map[interface{}]interface{}
for _, path := range c.files {
b, err := ioutil.ReadFile(path)
if err != nil {
return err
}
var nm map[interface{}]interface{}
err = yaml.Unmarshal(b, &nm)
if err != nil {
return err
}
// We need to use WithAppendSlice so that firewall rules in separate
// files are appended together
err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
m = nm
if err != nil {
return err
}
}
c.Settings = m
return nil
}
func readDirNames(path string) ([]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
paths, err := f.Readdirnames(-1)
f.Close()
if err != nil {
return nil, err
}
sort.Strings(paths)
return paths, nil
}
func configLogger(c *Config) error {
// set up our logging level
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
if err != nil {
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
}
l.SetLevel(logLevel)
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
switch logFormat {
case "text":
l.Formatter = &logrus.TextFormatter{}
case "json":
l.Formatter = &logrus.JSONFormatter{}
default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
}
return nil
}

141
config_test.go Normal file
View File

@ -0,0 +1,141 @@
package nebula
import (
"github.com/stretchr/testify/assert"
"io/ioutil"
"os"
"path/filepath"
"testing"
"time"
)
func TestConfig_Load(t *testing.T) {
dir, err := ioutil.TempDir("", "config-test")
// invalid yaml
c := NewConfig()
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
// simple multi config merge
c = NewConfig()
os.RemoveAll(dir)
os.Mkdir(dir, 0755)
assert.Nil(t, err)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
ioutil.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
assert.Nil(t, c.Load(dir))
expected := map[interface{}]interface{}{
"outer": map[interface{}]interface{}{
"inner": "override",
},
"new": "hi",
}
assert.Equal(t, expected, c.Settings)
//TODO: test symlinked file
//TODO: test symlinked directory
}
func TestConfig_Get(t *testing.T) {
// test simple type
c := NewConfig()
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
assert.Equal(t, "hi", c.Get("firewall.outbound"))
// test complex type
inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}}
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner}
assert.EqualValues(t, inner, c.Get("firewall.outbound"))
// test missing
assert.Nil(t, c.Get("firewall.nope"))
}
func TestConfig_GetStringSlice(t *testing.T) {
c := NewConfig()
c.Settings["slice"] = []interface{}{"one", "two"}
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
}
func TestConfig_GetBool(t *testing.T) {
c := NewConfig()
c.Settings["bool"] = true
assert.Equal(t, true, c.GetBool("bool", false))
c.Settings["bool"] = "true"
assert.Equal(t, true, c.GetBool("bool", false))
c.Settings["bool"] = false
assert.Equal(t, false, c.GetBool("bool", true))
c.Settings["bool"] = "false"
assert.Equal(t, false, c.GetBool("bool", true))
c.Settings["bool"] = "Y"
assert.Equal(t, true, c.GetBool("bool", false))
c.Settings["bool"] = "yEs"
assert.Equal(t, true, c.GetBool("bool", false))
c.Settings["bool"] = "N"
assert.Equal(t, false, c.GetBool("bool", true))
c.Settings["bool"] = "nO"
assert.Equal(t, false, c.GetBool("bool", true))
}
func TestConfig_HasChanged(t *testing.T) {
// No reload has occurred, return false
c := NewConfig()
c.Settings["test"] = "hi"
assert.False(t, c.HasChanged(""))
// Test key change
c = NewConfig()
c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "no"}
assert.True(t, c.HasChanged("test"))
assert.True(t, c.HasChanged(""))
// No key change
c = NewConfig()
c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
assert.False(t, c.HasChanged("test"))
assert.False(t, c.HasChanged(""))
}
func TestConfig_ReloadConfig(t *testing.T) {
done := make(chan bool, 1)
dir, err := ioutil.TempDir("", "config-test")
assert.Nil(t, err)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
c := NewConfig()
assert.Nil(t, c.Load(dir))
assert.False(t, c.HasChanged("outer.inner"))
assert.False(t, c.HasChanged("outer"))
assert.False(t, c.HasChanged(""))
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644)
c.RegisterReloadCallback(func(c *Config) {
done <- true
})
c.ReloadConfig()
assert.True(t, c.HasChanged("outer.inner"))
assert.True(t, c.HasChanged("outer"))
assert.True(t, c.HasChanged(""))
// Make sure we call the callbacks
select {
case <-done:
case <-time.After(1 * time.Second):
panic("timeout")
}
}

253
connection_manager.go Normal file
View File

@ -0,0 +1,253 @@
package nebula
import (
"github.com/sirupsen/logrus"
"sync"
"time"
)
// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
// and something like every 10 packets we could lock, send 10, then unlock for a moment
type connectionManager struct {
hostMap *HostMap
in map[uint32]struct{}
inLock *sync.RWMutex
inCount int
out map[uint32]struct{}
outLock *sync.RWMutex
outCount int
TrafficTimer *SystemTimerWheel
intf *Interface
pendingDeletion map[uint32]int
pendingDeletionLock *sync.RWMutex
pendingDeletionTimer *SystemTimerWheel
checkInterval int
pendingDeletionInterval int
// I wanted to call one matLock
}
func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
nc := &connectionManager{
hostMap: intf.hostMap,
in: make(map[uint32]struct{}),
inLock: &sync.RWMutex{},
inCount: 0,
out: make(map[uint32]struct{}),
outLock: &sync.RWMutex{},
outCount: 0,
TrafficTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
intf: intf,
pendingDeletion: make(map[uint32]int),
pendingDeletionLock: &sync.RWMutex{},
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
}
nc.Start()
return nc
}
func (n *connectionManager) In(ip uint32) {
n.inLock.RLock()
// If this already exists, return
if _, ok := n.in[ip]; ok {
n.inLock.RUnlock()
return
}
n.inLock.RUnlock()
n.inLock.Lock()
n.in[ip] = struct{}{}
n.inLock.Unlock()
}
func (n *connectionManager) Out(ip uint32) {
n.outLock.RLock()
// If this already exists, return
if _, ok := n.out[ip]; ok {
n.outLock.RUnlock()
return
}
n.outLock.RUnlock()
n.outLock.Lock()
// double check since we dropped the lock temporarily
if _, ok := n.out[ip]; ok {
n.outLock.Unlock()
return
}
n.out[ip] = struct{}{}
n.AddTrafficWatch(ip, n.checkInterval)
n.outLock.Unlock()
}
func (n *connectionManager) CheckIn(vpnIP uint32) bool {
n.inLock.RLock()
if _, ok := n.in[vpnIP]; ok {
n.inLock.RUnlock()
return true
}
n.inLock.RUnlock()
return false
}
func (n *connectionManager) ClearIP(ip uint32) {
n.inLock.Lock()
n.outLock.Lock()
delete(n.in, ip)
delete(n.out, ip)
n.inLock.Unlock()
n.outLock.Unlock()
}
func (n *connectionManager) ClearPendingDeletion(ip uint32) {
n.pendingDeletionLock.Lock()
delete(n.pendingDeletion, ip)
n.pendingDeletionLock.Unlock()
}
func (n *connectionManager) AddPendingDeletion(ip uint32) {
n.pendingDeletionLock.Lock()
if _, ok := n.pendingDeletion[ip]; ok {
n.pendingDeletion[ip] += 1
} else {
n.pendingDeletion[ip] = 0
}
n.pendingDeletionTimer.Add(ip, time.Second*time.Duration(n.pendingDeletionInterval))
n.pendingDeletionLock.Unlock()
}
func (n *connectionManager) checkPendingDeletion(ip uint32) bool {
n.pendingDeletionLock.RLock()
if _, ok := n.pendingDeletion[ip]; ok {
n.pendingDeletionLock.RUnlock()
return true
}
n.pendingDeletionLock.RUnlock()
return false
}
func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) {
n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds))
}
func (n *connectionManager) Start() {
go n.Run()
}
func (n *connectionManager) Run() {
clockSource := time.Tick(500 * time.Millisecond)
for now := range clockSource {
n.HandleMonitorTick(now)
n.HandleDeletionTick(now)
}
}
func (n *connectionManager) HandleMonitorTick(now time.Time) {
n.TrafficTimer.advance(now)
for {
ep := n.TrafficTimer.Purge()
if ep == nil {
break
}
vpnIP := ep.(uint32)
// Check for traffic coming back in from this host.
traf := n.CheckIn(vpnIP)
// If we saw incoming packets from this ip, just return
if traf {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIP)).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
}
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
continue
}
// If we didn't we may need to probe or destroy the conn
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
if err != nil {
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
continue
}
l.WithField("vpnIp", IntIp(vpnIP)).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
if hostinfo != nil && hostinfo.ConnectionState != nil {
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
} else {
l.Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
}
n.AddPendingDeletion(vpnIP)
}
}
func (n *connectionManager) HandleDeletionTick(now time.Time) {
n.pendingDeletionTimer.advance(now)
for {
ep := n.pendingDeletionTimer.Purge()
if ep == nil {
break
}
vpnIP := ep.(uint32)
// If we saw incoming packets from this ip, just return
traf := n.CheckIn(vpnIP)
if traf {
l.WithField("vpnIp", IntIp(vpnIP)).
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
Debug("Tunnel status")
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
continue
}
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
if err != nil {
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
continue
}
// If it comes around on deletion wheel and hasn't resolved itself, delete
if n.checkPendingDeletion(vpnIP) {
cn := ""
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
cn = hostinfo.ConnectionState.peerCert.Details.Name
}
l.WithField("vpnIp", IntIp(vpnIP)).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
WithField("certName", cn).
Info("Tunnel status")
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
// TODO: This is only here to let tests work. Should do proper mocking
if n.intf.lightHouse != nil {
n.intf.lightHouse.DeleteVpnIP(vpnIP)
}
n.hostMap.DeleteVpnIP(vpnIP)
n.hostMap.DeleteIndex(hostinfo.localIndexId)
} else {
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
}
}
}

141
connection_manager_test.go Normal file
View File

@ -0,0 +1,141 @@
package nebula
import (
"net"
"testing"
"time"
"github.com/flynn/noise"
"github.com/stretchr/testify/assert"
"github.com/slackhq/nebula/cert"
)
var vpnIP uint32 = uint32(12341234)
func Test_NewConnectionManagerTest(t *testing.T) {
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap("test", vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
certificate: &cert.NebulaCertificate{},
rawCertificateNoKey: []byte{},
}
lh := NewLightHouse(false, 0, []string{}, 1000, 0, &udpConn{}, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
outside: &udpConn{},
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}),
}
now := time.Now()
// Create manager
nc := newConnectionManager(ifce, 5, 10)
nc.HandleMonitorTick(now)
// Add an ip we have established a connection w/ to hostmap
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
messageCounter: new(uint64),
}
// We saw traffic out to vpnIP
nc.Out(vpnIP)
assert.NotContains(t, nc.pendingDeletion, vpnIP)
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
// Move ahead 5s. Nothing should happen
next_tick := now.Add(5 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleDeletionTick(next_tick)
// Move ahead 6s. We haven't heard back
next_tick = now.Add(6 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleDeletionTick(next_tick)
// This host should now be up for deletion
assert.Contains(t, nc.pendingDeletion, vpnIP)
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
// Move ahead some more
next_tick = now.Add(45 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleDeletionTick(next_tick)
// The host should be evicted
assert.NotContains(t, nc.pendingDeletion, vpnIP)
assert.NotContains(t, nc.hostMap.Hosts, vpnIP)
}
func Test_NewConnectionManagerTest2(t *testing.T) {
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap("test", vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
certificate: &cert.NebulaCertificate{},
rawCertificateNoKey: []byte{},
}
lh := NewLightHouse(false, 0, []string{}, 1000, 0, &udpConn{}, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
outside: &udpConn{},
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}),
}
now := time.Now()
// Create manager
nc := newConnectionManager(ifce, 5, 10)
nc.HandleMonitorTick(now)
// Add an ip we have established a connection w/ to hostmap
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
messageCounter: new(uint64),
}
// We saw traffic out to vpnIP
nc.Out(vpnIP)
assert.NotContains(t, nc.pendingDeletion, vpnIP)
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
// Move ahead 5s. Nothing should happen
next_tick := now.Add(5 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleDeletionTick(next_tick)
// Move ahead 6s. We haven't heard back
next_tick = now.Add(6 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleDeletionTick(next_tick)
// This host should now be up for deletion
assert.Contains(t, nc.pendingDeletion, vpnIP)
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
// We heard back this time
nc.In(vpnIP)
// Move ahead some more
next_tick = now.Add(45 * time.Second)
nc.HandleMonitorTick(next_tick)
nc.HandleDeletionTick(next_tick)
// The host should be evicted
assert.NotContains(t, nc.pendingDeletion, vpnIP)
assert.Contains(t, nc.hostMap.Hosts, vpnIP)
}

75
connection_state.go Normal file
View File

@ -0,0 +1,75 @@
package nebula
import (
"crypto/rand"
"encoding/json"
"sync"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
)
const ReplayWindow = 1024
type ConnectionState struct {
eKey *NebulaCipherState
dKey *NebulaCipherState
H *noise.HandshakeState
certState *CertState
peerCert *cert.NebulaCertificate
initiator bool
messageCounter *uint64
window *Bits
queueLock sync.Mutex
writeLock sync.Mutex
ready bool
}
func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
if f.cipher == "chachapoly" {
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
}
curCertState := f.certState
static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
b.Update(0)
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: cs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
PresharedKey: psk,
PresharedKeyPlacement: pskStage,
})
if err != nil {
return nil
}
// The queue and ready params prevent a counter race that would happen when
// sending stored packets and simultaneously accepting new traffic.
ci := &ConnectionState{
H: hs,
initiator: initiator,
window: b,
ready: false,
certState: curCertState,
messageCounter: new(uint64),
}
return ci
}
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"certificate": cs.peerCert,
"initiator": cs.initiator,
"message_counter": cs.messageCounter,
"ready": cs.ready,
})
}

125
dns_server.go Normal file
View File

@ -0,0 +1,125 @@
package nebula
import (
"fmt"
"net"
"strconv"
"sync"
"github.com/miekg/dns"
)
// This whole thing should be rewritten to use context
var dnsR *dnsRecords
type dnsRecords struct {
sync.RWMutex
dnsMap map[string]string
hostMap *HostMap
}
func newDnsRecords(hostMap *HostMap) *dnsRecords {
return &dnsRecords{
dnsMap: make(map[string]string),
hostMap: hostMap,
}
}
func (d *dnsRecords) Query(data string) string {
d.RLock()
if r, ok := d.dnsMap[data]; ok {
d.RUnlock()
return r
}
d.RUnlock()
return ""
}
func (d *dnsRecords) QueryCert(data string) string {
ip := net.ParseIP(data[:len(data)-1])
if ip == nil {
return ""
}
iip := ip2int(ip)
hostinfo, err := d.hostMap.QueryVpnIP(iip)
if err != nil {
return ""
}
q := hostinfo.GetCert()
if q == nil {
return ""
}
cert := q.Details
c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAFter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
return c
}
func (d *dnsRecords) Add(host, data string) {
d.Lock()
d.dnsMap[host] = data
d.Unlock()
}
func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeA:
l.Debugf("Query for A %s", q.Name)
ip := dnsR.Query(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
b := net.ParseIP(a)
// We don't answer these queries from non nebula nodes or localhost
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
return
}
l.Debugf("Query for TXT %s", q.Name)
ip := dnsR.QueryCert(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
}
}
}
func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Compress = false
switch r.Opcode {
case dns.OpcodeQuery:
parseQuery(m, w)
}
w.WriteMsg(m)
}
func dnsMain(hostMap *HostMap) {
dnsR = newDnsRecords(hostMap)
// attach request handler func
dns.HandleFunc(".", handleDnsRequest)
// start server
port := 53
server := &dns.Server{Addr: ":" + strconv.Itoa(port), Net: "udp"}
l.Debugf("Starting DNS responder at %d\n", port)
err := server.ListenAndServe()
defer server.Shutdown()
if err != nil {
l.Errorf("Failed to start server: %s\n ", err.Error())
}
}

19
dns_server_test.go Normal file
View File

@ -0,0 +1,19 @@
package nebula
import (
"testing"
"github.com/miekg/dns"
)
func TestParsequery(t *testing.T) {
//TODO: This test is basically pointless
hostMap := &HostMap{}
ds := newDnsRecords(hostMap)
ds.Add("test.com.com", "1.2.3.4")
m := new(dns.Msg)
m.SetQuestion("test.com.com", dns.TypeA)
//parseQuery(m)
}

160
examples/config.yaml Normal file
View File

@ -0,0 +1,160 @@
# This is the nebula example configuration file. You must edit, at a minimum, the static_host_map, lighthouse, and firewall sections
# PKI defines the location of credentials for this node. Each of these can also be inlined by using the yaml ": |" syntax.
pki:
ca: /etc/nebula/ca.crt
cert: /etc/nebula/host.crt
key: /etc/nebula/host.key
#blacklist is a list of certificate fingerprints that we will refuse to talk to
#blacklist:
# - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
# The static host map defines a set of hosts with with fixed IP addresses on the internet (or any network).
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
# The syntax is:
# "{nebula ip}": ["{routable ip/dns name}:{routable port}"]
# Example, if your lighthouse has the nebula IP of 192.168.100.1 and has the real ip address of 100.64.22.11 and runs on port 4242:
static_host_map:
"192.168.100.1": ["100.64.22.11:4242"]
lighthouse:
# am_lighthouse is used to enable lighthouse functionality for a node. This should ONLY be true on nodes
# you have configured to be lighthouses in your network
am_lighthouse: false
# serve_dns optionally starts a dns listener that responds to various queries and can even be
# delegated to for resolution
#serve_dns: false
# interval is the number of seconds between updates from this node to a lighthouse.
# during updates, a node sends information about its current IP addresses to each node.
interval: 60
# hosts is a list of lighthouse hosts this node should report to and query from
# IMPORTANT: THIS SHOULD BE EMPTY ON LIGHTHOUSE NODES
hosts:
- "192.168.100.1"
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
# however using port 0 will dynamically assign a port and is recommended for roaming nodes.
listen:
host: 0.0.0.0
port: 4242
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
# default is 64, does not support reload
#batch: 64
# Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel
# Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/rmem_default)
# Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE is used to avoid having to raise the system wide
# max, net.core.rmem_max and net.core.wmem_max
#read_buffer: 10485760
#write_buffer: 10485760
# Local range is used to define a hint about the local network range, which speeds up discovering the fastest
# path to a network adjacent nebula node.
#local_range: "172.16.0.0/24"
# Handshake mac is an optional network-wide handshake authentication step that is used to prevent nebula from
# responding to handshakes from nodes not in possession of the shared secret. This is primarily used to prevent
# detection of nebula nodes when someone is scanning a network.
#handshake_mac:
#key: "DONOTUSETHISKEY"
# You can define multiple accepted keys
#accepted_keys:
#- "DONOTUSETHISKEY"
#- "dontusethiseither"
# sshd can expose informational and administrative functions via ssh this is a
#sshd:
# Toggles the feature
#enabled: true
# Host and port to listen on, port 22 is not allowed for your safety
#listen: 127.0.0.1:2222
# A file containing the ssh host private key to use
# A decent way to generate one: ssh-keygen -t ed25519 -f ssh_host_ed25519_key -N "" < /dev/null
#host_key: ./ssh_host_ed25519_key
# A file containing a list of authorized public keys
#authorized_users:
#- user: steeeeve
# keys can be an array of strings or single string
#keys:
#- "ssh public key string"
# Configure the private interface. Note: addr is baked into the nebula certificate
tun:
# Name of the device
dev: nebula1
# Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert
drop_local_broadcast: false
# Toggles forwarding of multicast packets
drop_multicast: false
# Sets the transmit queue length, if you notice lots of transmit drops on the tun it may help to raise this number. Default is 500
tx_queue: 500
# Default MTU for every packet, safe setting is (and the default) 1300 for internet based traffic
mtu: 1300
# Route based MTU overrides, you have known vpn ip paths that can support larger MTUs you can increase/decrease them here
routes:
#- mtu: 8800
# route: 10.0.0.0/16
# TODO
# Configure logging level
logging:
# panic, fatal, error, warning, info, or debug. Default is info
level: info
# json or text formats currently available. Default is text
format: text
#stats:
#type: graphite
#prefix: nebula
#protocol: tcp
#host: 127.0.0.1:9999
#interval: 10s
#type: prometheus
#listen: 127.0.0.1:8080
#path: /metrics
#namespace: prometheusns
#subsystem: nebula
#interval: 10s
# Nebula security group configuration
firewall:
conntrack:
tcp_timeout: 120h
udp_timeout: 3m
default_timeout: 10m
max_connections: 100000
# The firewall is default deny. There is no way to write a deny rule.
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
# Logical evaluation is roughly: port AND proto AND ca_sha AND ca_name AND (host OR group OR groups OR cidr)
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
# code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
# proto: `any`, `tcp`, `udp`, or `icmp`
# host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a CIDR, `0.0.0.0/0` is any.
# ca_name: An issuing CA name
# ca_sha: An issuing CA shasum
outbound:
# Allow all outbound traffic from this node
- port: any
proto: any
host: any
inbound:
# Allow icmp between any nebula hosts
- port: any
proto: icmp
host: any
# Allow tcp/443 from any host with BOTH laptop and home group
- port: 443
proto: tcp
groups:
- laptop
- home

View File

@ -0,0 +1,154 @@
# Quickstart Guide
This guide is intended to bring up a vagrant environment with 1 lighthouse and 2 generic hosts running nebula.
## Pre-requisites
There are two pre-requisites prior to bringing up the vagrant environment
- build the binaries locally for the vagrant deploy
- create a virtualenv for ansible
### Building the binaries
Build the `nebula` and `nebula-cert` binaries for vagrant by doing the following
`make bin-vagrant` (under the src directory with Makefile)
For convenience, ansible will run this for you in every deploy (see `ansible/playbook.yml`)
### Creating the virtualenv
Within the `quickstart/` directory, do the following
```
# make a virtual environment
virtualenv venv
# get into the virtualenv
source venv/bin/activate
# install ansible
pip install -r requirements.yml
```
## Bringing up the vagrant environment
A plugin that is used for the Vagrant environment is `vagrant-hostmanager`
To install, run
```
vagrant plugin install vagrant-hostmanager
```
All hosts within the Vagrantfile are brought up with
`vagrant up`
Once the boxes are up, go into the `ansible/` directory and deploy the playbook by running
`ansible-playbook playbook.yml -i inventory -u vagrant`
## Testing within the vagrant env
Once the ansible run is done, hop onto a vagrant box
`vagrant ssh generic1.vagrant`
or specifically
`ssh vagrant@<ip-address-in-vagrant-file` (password for the vagrant user on the boxes is `vagrant`)
Some quick tests once the vagrant boxes are up are to ping from `generic1.vagrant` to `generic2.vagrant` using
their respective nebula ip address.
```
vagrant@generic1:~$ ping 10.168.91.220
PING 10.168.91.220 (10.168.91.220) 56(84) bytes of data.
64 bytes from 10.168.91.220: icmp_seq=1 ttl=64 time=241 ms
64 bytes from 10.168.91.220: icmp_seq=2 ttl=64 time=0.704 ms
```
You can further verify that the allowed nebula firewall rules work by ssh'ing from 1 generic box to the other.
`ssh vagrant@<nebula-ip-address>` (password for the vagrant user on the boxes is `vagrant`)
See `/etc/nebula/config.yml` on a box for firewall rules.
To see full handshakes and hostmaps, change the logging config of `/etc/nebula/config.yml` on the vagrant boxes from
info to debug.
You can watch nebula logs by running
```
sudo journalctl -fu nebula
```
Refer to the nebula src code directory's README for further instructions on configuring nebula.
## Troubleshooting
### Is nebula up and running?
Run and verify that
```
ifconfig
```
shows you an interface with the name `nebula1` being up.
```
vagrant@generic1:~$ ifconfig nebula1
nebula1: flags=4305<UP,POINTOPOINT,RUNNING,NOARP,MULTICAST> mtu 1300
inet 10.168.91.210 netmask 255.128.0.0 destination 10.168.91.210
inet6 fe80::aeaf:b105:e6dc:936c prefixlen 64 scopeid 0x20<link>
unspec 00-00-00-00-00-00-00-00-00-00-00-00-00-00-00-00 txqueuelen 500 (UNSPEC)
RX packets 2 bytes 168 (168.0 B)
RX errors 0 dropped 0 overruns 0 frame 0
TX packets 11 bytes 600 (600.0 B)
TX errors 0 dropped 0 overruns 0 carrier 0 collisions 0
```
### Connectivity
Are you able to ping other boxes on the private nebula network?
The following are the private nebula ip addresses of the vagrant env
```
generic1.vagrant [nebula_ip] 10.168.91.210
generic2.vagrant [nebula_ip] 10.168.91.220
lighthouse1.vagrant [nebula_ip] 10.168.91.230
```
Try pinging generic1.vagrant to and from any other box using its nebula ip above.
Double check the nebula firewall rules under /etc/nebula/config.yml to make sure that connectivity is allowed for your use-case if on a specific port.
```
vagrant@lighthouse1:~$ grep -A21 firewall /etc/nebula/config.yml
firewall:
conntrack:
tcp_timeout: 12m
udp_timeout: 3m
default_timeout: 10m
max_connections: 100,000
inbound:
- proto: icmp
port: any
host: any
- proto: any
port: 22
host: any
- proto: any
port: 53
host: any
outbound:
- proto: any
port: any
host: any
```

40
examples/quickstart-vagrant/Vagrantfile vendored Normal file
View File

@ -0,0 +1,40 @@
Vagrant.require_version ">= 2.2.6"
nodes = [
{ :hostname => 'generic1.vagrant', :ip => '172.11.91.210', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1},
{ :hostname => 'generic2.vagrant', :ip => '172.11.91.220', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1},
{ :hostname => 'lighthouse1.vagrant', :ip => '172.11.91.230', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1},
]
Vagrant.configure("2") do |config|
config.ssh.insert_key = false
if Vagrant.has_plugin?('vagrant-cachier')
config.cache.enable :apt
else
printf("** Install vagrant-cachier plugin to speedup deploy: `vagrant plugin install vagrant-cachier`.**\n")
end
if Vagrant.has_plugin?('vagrant-hostmanager')
config.hostmanager.enabled = true
config.hostmanager.manage_host = true
config.hostmanager.include_offline = true
else
config.vagrant.plugins = "vagrant-hostmanager"
end
nodes.each do |node|
config.vm.define node[:hostname] do |node_config|
node_config.vm.box = node[:box]
node_config.vm.hostname = node[:hostname]
node_config.vm.network :private_network, ip: node[:ip]
node_config.vm.provider :virtualbox do |vb|
vb.memory = node[:ram]
vb.cpus = node[:cpus]
vb.customize ["modifyvm", :id, "--natdnshostresolver1", "on"]
vb.customize ['guestproperty', 'set', :id, '/VirtualBox/GuestAdd/VBoxService/--timesync-set-threshold', 10000]
end
end
end
end

View File

@ -0,0 +1,4 @@
[defaults]
host_key_checking = False
private_key_file = ~/.vagrant.d/insecure_private_key
become = yes

View File

@ -0,0 +1,21 @@
#!/usr/bin/python
class FilterModule(object):
def filters(self):
return {
'to_nebula_ip': self.to_nebula_ip,
'map_to_nebula_ips': self.map_to_nebula_ips,
}
def to_nebula_ip(self, ip_str):
ip_list = map(int, ip_str.split("."))
ip_list[0] = 10
ip_list[1] = 168
ip = '.'.join(map(str, ip_list))
return ip
def map_to_nebula_ips(self, ip_strs):
ip_list = [ self.to_nebula_ip(ip_str) for ip_str in ip_strs ]
ips = ', '.join(ip_list)
return ips

View File

@ -0,0 +1,11 @@
[all]
generic1.vagrant
generic2.vagrant
lighthouse1.vagrant
[generic]
generic1.vagrant
generic2.vagrant
[lighthouse]
lighthouse1.vagrant

View File

@ -0,0 +1,20 @@
---
- name: test connection to vagrant boxes
hosts: all
tasks:
- debug: msg=ok
- name: build nebula binaries locally
connection: local
hosts: localhost
tasks:
- command: chdir=../../../ make bin-vagrant
tags:
- build-nebula
- name: install nebula on all vagrant hosts
hosts: all
become: yes
gather_facts: yes
roles:
- nebula

View File

@ -0,0 +1,3 @@
---
# defaults file for nebula
nebula_config_directory: "/etc/nebula/"

View File

@ -0,0 +1,15 @@
[Unit]
Description=nebula
Wants=basic.target
After=basic.target network.target
[Service]
SyslogIdentifier=nebula
StandardOutput=syslog
StandardError=syslog
ExecReload=/bin/kill -HUP $MAINPID
ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml
Restart=always
[Install]
WantedBy=multi-user.target

View File

@ -0,0 +1,5 @@
-----BEGIN NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSB0ZXN0IENBKNXC1NYFMNXIhO0GOiCmVYeZ9tkB4WEnawmkrca+
hsAg9otUFhpAowZeJ33KVEABEkAORybHQUUyVFbKYzw0JHfVzAQOHA4kwB1yP9IV
KpiTw9+ADz+wA+R5tn9B+L8+7+Apc+9dem4BQULjA5mRaoYN
-----END NEBULA CERTIFICATE-----

View File

@ -0,0 +1,4 @@
-----BEGIN NEBULA ED25519 PRIVATE KEY-----
FEXZKMSmg8CgIODR0ymUeNT3nbnVpMi7nD79UgkCRHWmVYeZ9tkB4WEnawmkrca+
hsAg9otUFhpAowZeJ33KVA==
-----END NEBULA ED25519 PRIVATE KEY-----

View File

@ -0,0 +1,5 @@
---
# handlers file for nebula
- name: restart nebula
service: name=nebula state=restarted

View File

@ -0,0 +1,56 @@
---
# tasks file for nebula
- name: get the vagrant network interface and set fact
set_fact:
vagrant_ifce: "ansible_{{ ansible_interfaces | difference(['lo',ansible_default_ipv4.alias]) | sort | first }}"
tags:
- nebula-conf
- name: install built nebula binary
copy: src=../../../../../{{ item }} dest=/usr/local/bin mode=0755
with_items:
- nebula
- nebula-cert
- name: create nebula config directory
file: path="{{ nebula_config_directory }}" state=directory mode=0755
- name: temporarily copy over root.crt and root.key to sign
copy: src={{ item }} dest=/opt/{{ item }}
with_items:
- vagrant-test-ca.key
- vagrant-test-ca.crt
- name: sign using the root key
command: nebula-cert sign -ca-crt /opt/vagrant-test-ca.crt -ca-key /opt/vagrant-test-ca.key -duration 4320h -groups vagrant -ip {{ hostvars[inventory_hostname][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }}/9 -name {{ ansible_hostname }}.nebula -out-crt /etc/nebula/host.crt -out-key /etc/nebula/host.key
- name: remove root.key used to sign
file: dest=/opt/{{ item }} state=absent
with_items:
- vagrant-test-ca.key
- name: write the content of the trusted ca certificate
copy: src="vagrant-test-ca.crt" dest="/etc/nebula/vagrant-test-ca.crt"
notify: restart nebula
- name: Create config directory
file: path="{{ nebula_config_directory }}" owner=root group=root mode=0755 state=directory
- name: nebula config
template: src=config.yml.j2 dest="/etc/nebula/config.yml" mode=0644 owner=root group=root
notify: restart nebula
tags:
- nebula-conf
- name: nebula systemd
copy: src=systemd.nebula.service dest="/etc/systemd/system/nebula.service" mode=0644 owner=root group=root
register: addconf
notify: restart nebula
- name: maybe reload systemd
shell: systemctl daemon-reload
when: addconf.changed
- name: nebula running
service: name="nebula" state=started enabled=yes

View File

@ -0,0 +1,84 @@
pki:
ca: /etc/nebula/vagrant-test-ca.crt
cert: /etc/nebula/host.crt
key: /etc/nebula/host.key
# Port Nebula will be listening on
listen:
host: 0.0.0.0
port: 4242
# sshd can expose informational and administrative functions via ssh
sshd:
# Toggles the feature
enabled: true
# Host and port to listen on
listen: 127.0.0.1:2222
# A file containing the ssh host private key to use
host_key: /etc/ssh/ssh_host_ed25519_key
# A file containing a list of authorized public keys
authorized_users:
{% for user in nebula_users %}
- user: {{ user.name }}
keys:
{% for key in user.ssh_auth_keys %}
- "{{ key }}"
{% endfor %}
{% endfor %}
local_range: 10.168.0.0/16
static_host_map:
# lighthouse
{{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }}: ["{{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address']}}:4242"]
default_route: "0.0.0.0"
lighthouse:
{% if 'lighthouse' in group_names %}
am_lighthouse: true
serve_dns: true
{% else %}
am_lighthouse: false
{% endif %}
interval: 60
hosts:
- {{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }}
# Configure the private interface
tun:
dev: nebula1
# Sets MTU of the tun dev.
# MTU of the tun must be smaller than the MTU of the eth0 interface
mtu: 1300
# TODO
# Configure logging level
logging:
level: info
format: json
firewall:
conntrack:
tcp_timeout: 12m
udp_timeout: 3m
default_timeout: 10m
max_connections: 100,000
inbound:
- proto: icmp
port: any
host: any
- proto: any
port: 22
host: any
{% if "lighthouse" in groups %}
- proto: any
port: 53
host: any
{% endif %}
outbound:
- proto: any
port: any
host: any

View File

@ -0,0 +1,7 @@
---
# vars file for nebula
nebula_users:
- name: user1
ssh_auth_keys:
- "ed25519 place-your-ssh-public-key-here"

View File

@ -0,0 +1 @@
ansible

View File

@ -0,0 +1,51 @@
#!/bin/sh
### BEGIN INIT INFO
# Provides: nebula
# Required-Start: $local_fs $network
# Required-Stop: $local_fs $network
# Default-Start: 2 3 4 5
# Default-Stop: 0 1 6
# Description: nebula mesh vpn client
### END INIT INFO
SCRIPT="/usr/local/bin/nebula -config /etc/nebula/config.yml"
RUNAS=root
PIDFILE=/var/run/nebula.pid
LOGFILE=/var/log/nebula.log
start() {
if [ -f $PIDFILE ] && kill -0 $(cat $PIDFILE); then
echo 'Service already running' >&2
return 1
fi
echo 'Starting nebula service…' >&2
local CMD="$SCRIPT &> \"$LOGFILE\" & echo \$!"
su -c "$CMD" $RUNAS > "$PIDFILE"
echo 'Service started' >&2
}
stop() {
if [ ! -f "$PIDFILE" ] || ! kill -0 $(cat "$PIDFILE"); then
echo 'Service not running' >&2
return 1
fi
echo 'Stopping nebula service…' >&2
kill -15 $(cat "$PIDFILE") && rm -f "$PIDFILE"
echo 'Service stopped' >&2
}
case "$1" in
start)
start
;;
stop)
stop
;;
restart)
stop
start
;;
*)
echo "Usage: $0 {start|stop|restart}"
esac

View File

@ -0,0 +1,15 @@
[Unit]
Description=nebula
Wants=basic.target
After=basic.target network.target
[Service]
SyslogIdentifier=nebula
StandardOutput=syslog
StandardError=syslog
ExecReload=/bin/kill -HUP $MAINPID
ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml
Restart=always
[Install]
WantedBy=multi-user.target

789
firewall.go Normal file
View File

@ -0,0 +1,789 @@
package nebula
import (
"encoding/binary"
"encoding/json"
"fmt"
"net"
"sync"
"time"
"crypto/sha256"
"encoding/hex"
"errors"
"reflect"
"strconv"
"strings"
"github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
)
const (
fwProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
fwProtoTCP = 6
fwProtoUDP = 17
fwProtoICMP = 1
fwPortAny = 0 // Special value for matching `port: any`
fwPortFragment = -1 // Special value for matching `port: fragment`
)
const tcpACK = 0x10
const tcpFIN = 0x01
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error
}
type conn struct {
Expires time.Time // Time when this conntrack entry will expire
Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
}
// TODO: need conntrack max tracked connections handling
type Firewall struct {
Conns map[FirewallPacket]*conn
InRules *FirewallTable
OutRules *FirewallTable
//TODO: we should have many more options for TCP, an option for ICMP, and mimic the kernel a bit better
// https://www.kernel.org/doc/Documentation/networking/nf_conntrack-sysctl.txt
TCPTimeout time.Duration //linux: 5 days max
UDPTimeout time.Duration //linux: 180s max
DefaultTimeout time.Duration //linux: 600s
TimerWheel *TimerWheel
// Used to ensure we don't emit local packets for ips we don't own
localIps *CIDRTree
connMutex sync.Mutex
rules string
trackTCPRTT bool
metricTCPRTT metrics.Histogram
}
type FirewallTable struct {
TCP firewallPort
UDP firewallPort
ICMP firewallPort
AnyProto firewallPort
}
func newFirewallTable() *FirewallTable {
return &FirewallTable{
TCP: firewallPort{},
UDP: firewallPort{},
ICMP: firewallPort{},
AnyProto: firewallPort{},
}
}
type FirewallRule struct {
// Any makes Hosts, Groups, and CIDR irrelevant. CAName and CASha still need to be checked
Any bool
Hosts map[string]struct{}
Groups [][]string
CIDR *CIDRTree
CANames map[string]struct{}
CAShas map[string]struct{}
}
// Even though ports are uint16, int32 maps are faster for lookup
// Plus we can use `-1` for fragment rules
type firewallPort map[int32]*FirewallRule
type FirewallPacket struct {
LocalIP uint32
RemoteIP uint32
LocalPort uint16
RemotePort uint16
Protocol uint8
Fragment bool
}
func (fp *FirewallPacket) Copy() *FirewallPacket {
return &FirewallPacket{
LocalIP: fp.LocalIP,
RemoteIP: fp.RemoteIP,
LocalPort: fp.LocalPort,
RemotePort: fp.RemotePort,
Protocol: fp.Protocol,
Fragment: fp.Fragment,
}
}
func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
var proto string
switch fp.Protocol {
case fwProtoTCP:
proto = "tcp"
case fwProtoICMP:
proto = "icmp"
case fwProtoUDP:
proto = "udp"
default:
proto = fmt.Sprintf("unknown %v", fp.Protocol)
}
return json.Marshal(m{
"LocalIP": int2ip(fp.LocalIP).String(),
"RemoteIP": int2ip(fp.RemoteIP).String(),
"LocalPort": fp.LocalPort,
"RemotePort": fp.RemotePort,
"Protocol": proto,
"Fragment": fp.Fragment,
})
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
//TODO: error on 0 duration
var min, max time.Duration
if tcpTimeout < UDPTimeout {
min = tcpTimeout
max = UDPTimeout
} else {
min = UDPTimeout
max = tcpTimeout
}
if defaultTimeout < min {
min = defaultTimeout
} else if defaultTimeout > max {
max = defaultTimeout
}
localIps := NewCIDRTree()
for _, ip := range c.Details.Ips {
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
}
for _, n := range c.Details.Subnets {
localIps.AddCIDR(n, struct{}{})
}
return &Firewall{
Conns: make(map[FirewallPacket]*conn),
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
TimerWheel: NewTimerWheel(min, max),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
localIps: localIps,
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
}
}
func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
fw := NewFirewall(
c.GetDuration("firewall.conntrack.tcp_timeout", time.Duration(time.Minute*12)),
c.GetDuration("firewall.conntrack.udp_timeout", time.Duration(time.Minute*3)),
c.GetDuration("firewall.conntrack.default_timeout", time.Duration(time.Minute*10)),
nc,
//TODO: max_connections
)
err := AddFirewallRulesFromConfig(false, c, fw)
if err != nil {
return nil, err
}
err = AddFirewallRulesFromConfig(true, c, fw)
if err != nil {
return nil, err
}
return fw, nil
}
// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, ip, caName, caSha,
)
f.rules += ruleString + "\n"
direction := "incoming"
if !incoming {
direction = "outgoing"
}
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (
ft *FirewallTable
fp firewallPort
)
if incoming {
ft = f.InRules
} else {
ft = f.OutRules
}
switch proto {
case fwProtoTCP:
fp = ft.TCP
case fwProtoUDP:
fp = ft.UDP
case fwProtoICMP:
fp = ft.ICMP
case fwProtoAny:
fp = ft.AnyProto
default:
return fmt.Errorf("unknown protocol %v", proto)
}
return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha)
}
// GetRuleHash returns a hash representation of all inbound and outbound rules
func (f *Firewall) GetRuleHash() string {
sum := sha256.Sum256([]byte(f.rules))
return hex.EncodeToString(sum[:])
}
func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterface) error {
var table string
if inbound {
table = "firewall.inbound"
} else {
table = "firewall.outbound"
}
r := config.Get(table)
if r == nil {
return nil
}
rs, ok := r.([]interface{})
if !ok {
return fmt.Errorf("%s failed to parse, should be an array of rules", table)
}
for i, t := range rs {
var groups []string
r, err := convertRule(t)
if err != nil {
return fmt.Errorf("%s rule #%v; %s", table, i, err)
}
if r.Code != "" && r.Port != "" {
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
}
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i)
}
if len(r.Groups) > 0 {
groups = r.Groups
}
if r.Group != "" {
// Check if we have both groups and group provided in the rule config
if len(groups) > 0 {
return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
}
groups = []string{r.Group}
}
var sPort, errPort string
if r.Code != "" {
errPort = "code"
sPort = r.Code
} else {
errPort = "port"
sPort = r.Port
}
startPort, endPort, err := parsePort(sPort)
if err != nil {
return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err)
}
var proto uint8
switch r.Proto {
case "any":
proto = fwProtoAny
case "tcp":
proto = fwProtoTCP
case "udp":
proto = fwProtoUDP
case "icmp":
proto = fwProtoICMP
default:
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
}
var cidr *net.IPNet
if r.Cidr != "" {
_, cidr, err = net.ParseCIDR(r.Cidr)
if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
}
}
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha)
if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
}
}
return nil
}
func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(packet, fp, incoming) {
return false
}
// Make sure we are supposed to be handling this local ip address
if f.localIps.Contains(fp.LocalIP) == nil {
return true
}
table := f.OutRules
if incoming {
table = f.InRules
}
// We now know which firewall table to check against
if !table.match(fp, incoming, c, caPool) {
return true
}
// We always want to conntrack since it is a faster operation
f.addConn(packet, fp, incoming)
return false
}
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
// firewall object is created
func (f *Firewall) Destroy() {
//TODO: clean references if/when needed
}
func (f *Firewall) EmitStats() {
conntrackCount := len(f.Conns)
metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
}
func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool {
f.connMutex.Lock()
// Purge every time we test
ep, has := f.TimerWheel.Purge()
if has {
f.evict(ep)
}
c, ok := f.Conns[fp]
if !ok {
f.connMutex.Unlock()
return false
}
switch fp.Protocol {
case fwProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout)
if incoming {
f.checkTCPRTT(c, packet)
} else {
setTCPRTTTracking(c, packet)
}
case fwProtoUDP:
c.Expires = time.Now().Add(f.UDPTimeout)
default:
c.Expires = time.Now().Add(f.DefaultTimeout)
}
f.connMutex.Unlock()
return true
}
func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
var timeout time.Duration
c := &conn{}
switch fp.Protocol {
case fwProtoTCP:
timeout = f.TCPTimeout
if !incoming {
setTCPRTTTracking(c, packet)
}
case fwProtoUDP:
timeout = f.UDPTimeout
default:
timeout = f.DefaultTimeout
}
f.connMutex.Lock()
if _, ok := f.Conns[fp]; !ok {
f.TimerWheel.Add(fp, timeout)
}
c.Expires = time.Now().Add(timeout)
f.Conns[fp] = c
f.connMutex.Unlock()
}
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock!
func (f *Firewall) evict(p FirewallPacket) {
//TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn?
t, ok := f.Conns[p]
if !ok {
return
}
newT := t.Expires.Sub(time.Now())
// Timeout is in the future, re-add the timer
if newT > 0 {
f.TimerWheel.Add(p, newT)
return
}
// This conn is done
delete(f.Conns, p)
}
func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
if ft.AnyProto.match(p, incoming, c, caPool) {
return true
}
switch p.Protocol {
case fwProtoTCP:
if ft.TCP.match(p, incoming, c, caPool) {
return true
}
case fwProtoUDP:
if ft.UDP.match(p, incoming, c, caPool) {
return true
}
case fwProtoICMP:
if ft.ICMP.match(p, incoming, c, caPool) {
return true
}
}
return false
}
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
if startPort > endPort {
return fmt.Errorf("start port was lower than end port")
}
for i := startPort; i <= endPort; i++ {
if _, ok := fp[i]; !ok {
fp[i] = &FirewallRule{
Groups: make([][]string, 0),
Hosts: make(map[string]struct{}),
CIDR: NewCIDRTree(),
CANames: make(map[string]struct{}),
CAShas: make(map[string]struct{}),
}
}
if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil {
return err
}
}
return nil
}
func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
// We don't have any allowed ports, bail
if fp == nil {
return false
}
var port int32
if p.Fragment {
port = fwPortFragment
} else if incoming {
port = int32(p.LocalPort)
} else {
port = int32(p.RemotePort)
}
if fp[port].match(p, c, caPool) {
return true
}
return fp[fwPortAny].match(p, c, caPool)
}
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
if caName != "" {
fr.CANames[caName] = struct{}{}
}
if caSha != "" {
fr.CAShas[caSha] = struct{}{}
}
if fr.Any {
return nil
}
if fr.isAny(groups, host, ip) {
fr.Any = true
// If it's any we need to wipe out any pre-existing rules to save on memory
fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{})
fr.CIDR = NewCIDRTree()
} else {
if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups)
}
if host != "" {
fr.Hosts[host] = struct{}{}
}
if ip != nil {
fr.CIDR.AddCIDR(ip, struct{}{})
}
}
return nil
}
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
for _, group := range groups {
if group == "any" {
return true
}
}
if host == "any" {
return true
}
if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) {
return true
}
return false
}
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
if fr == nil {
return false
}
// CASha and CAName always need to be checked
if len(fr.CAShas) > 0 {
if _, ok := fr.CAShas[c.Details.Issuer]; !ok {
return false
}
}
if len(fr.CANames) > 0 {
s, err := caPool.GetCAForCert(c)
if err != nil {
return false
}
if _, ok := fr.CANames[s.Details.Name]; !ok {
return false
}
}
// Shortcut path for if groups, hosts, or cidr contained an `any`
if fr.Any {
return true
}
// Need any of group, host, or cidr to match
for _, sg := range fr.Groups {
found := false
for _, g := range sg {
if _, ok := c.Details.InvertedGroups[g]; !ok {
found = false
break
}
found = true
}
if found {
return true
}
}
if fr.Hosts != nil {
if _, ok := fr.Hosts[c.Details.Name]; ok {
return true
}
}
if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil {
return true
}
// No host, group, or cidr matched, bye bye
return false
}
type rule struct {
Port string
Code string
Proto string
Host string
Group string
Groups []string
Cidr string
CAName string
CASha string
}
func convertRule(p interface{}) (rule, error) {
r := rule{}
m, ok := p.(map[interface{}]interface{})
if !ok {
return r, errors.New("could not parse rule")
}
toString := func(k string, m map[interface{}]interface{}) string {
v, ok := m[k]
if !ok {
return ""
}
return fmt.Sprintf("%v", v)
}
r.Port = toString("port", m)
r.Code = toString("code", m)
r.Proto = toString("proto", m)
r.Host = toString("host", m)
r.Group = toString("group", m)
r.Cidr = toString("cidr", m)
r.CAName = toString("ca_name", m)
r.CASha = toString("ca_sha", m)
if rg, ok := m["groups"]; ok {
switch reflect.TypeOf(rg).Kind() {
case reflect.Slice:
v := reflect.ValueOf(rg)
r.Groups = make([]string, v.Len())
for i := 0; i < v.Len(); i++ {
r.Groups[i] = v.Index(i).Interface().(string)
}
case reflect.String:
r.Groups = []string{rg.(string)}
default:
r.Groups = []string{fmt.Sprintf("%v", rg)}
}
}
return r, nil
}
func parsePort(s string) (startPort, endPort int32, err error) {
if s == "any" {
startPort = fwPortAny
endPort = fwPortAny
} else if s == "fragment" {
startPort = fwPortFragment
endPort = fwPortFragment
} else if strings.Contains(s, `-`) {
sPorts := strings.SplitN(s, `-`, 2)
sPorts[0] = strings.Trim(sPorts[0], " ")
sPorts[1] = strings.Trim(sPorts[1], " ")
if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" {
return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
}
rStartPort, err := strconv.Atoi(sPorts[0])
if err != nil {
return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
}
rEndPort, err := strconv.Atoi(sPorts[1])
if err != nil {
return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
}
startPort = int32(rStartPort)
endPort = int32(rEndPort)
if startPort == fwPortAny {
endPort = fwPortAny
}
} else {
rPort, err := strconv.Atoi(s)
if err != nil {
return 0, 0, fmt.Errorf("was not a number; `%s`", s)
}
startPort = int32(rPort)
endPort = startPort
}
return
}
//TODO: write tests for these
func setTCPRTTTracking(c *conn, p []byte) {
if c.Seq != 0 {
return
}
ihl := int(p[0]&0x0f) << 2
// Don't track FIN packets
if uint8(p[ihl+13])&tcpFIN != 0 {
return
}
c.Seq = binary.BigEndian.Uint32(p[ihl+4 : ihl+8])
c.Sent = time.Now()
}
func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
if c.Seq == 0 {
return false
}
ihl := int(p[0]&0x0f) << 2
if uint8(p[ihl+13])&tcpACK == 0 {
return false
}
// Deal with wrap around, signed int cuts the ack window in half
// 0 is a bad ack, no data acknowledged
// positive number is a bad ack, ack is over half the window away
if int32(c.Seq-binary.BigEndian.Uint32(p[ihl+8:ihl+12])) >= 0 {
return false
}
f.metricTCPRTT.Update(time.Since(c.Sent).Nanoseconds())
c.Seq = 0
return true
}

687
firewall_test.go Normal file
View File

@ -0,0 +1,687 @@
package nebula
import (
"encoding/binary"
"errors"
"github.com/rcrowley/go-metrics"
"github.com/stretchr/testify/assert"
"math"
"net"
"github.com/slackhq/nebula/cert"
"testing"
"time"
)
func TestNewFirewall(t *testing.T) {
c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.NotNil(t, fw.Conns)
assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules)
assert.NotNil(t, fw.TimerWheel)
assert.Equal(t, time.Second, fw.TCPTimeout)
assert.Equal(t, time.Minute, fw.UDPTimeout)
assert.Equal(t, time.Hour, fw.DefaultTimeout)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
}
func TestFirewall_AddRule(t *testing.T) {
c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules)
_, ti, _ := net.ParseCIDR("1.2.3.4/32")
assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
// Make sure an empty rule creates structure but doesn't allow anything to flow
//TODO: ideally an empty rule would return an error
assert.False(t, fw.InRules.TCP[1].Any)
assert.Empty(t, fw.InRules.TCP[1].Groups)
assert.Empty(t, fw.InRules.TCP[1].Hosts)
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left)
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right)
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
assert.False(t, fw.InRules.UDP[1].Any)
assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1")
assert.Empty(t, fw.InRules.UDP[1].Hosts)
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left)
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right)
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
assert.False(t, fw.InRules.ICMP[1].Any)
assert.Empty(t, fw.InRules.ICMP[1].Groups)
assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1")
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left)
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right)
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Hosts)
assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP)))
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
// Set any and clear fields
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0])
assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1")
assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP)))
// run twice just to make sure
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any)
assert.Empty(t, fw.OutRules.AnyProto[0].Groups)
assert.Empty(t, fw.OutRules.AnyProto[0].Hosts)
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left)
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right)
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any)
// Test error conditions
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
}
func TestFirewall_Drop(t *testing.T) {
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
101,
10,
90,
fwProtoUDP,
false,
}
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
Groups: []string{"default-group"},
Issuer: "signer-shasum",
},
}
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.True(t, fw.Drop([]byte{}, p, false, &c, cp))
// Allow inbound
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
// Allow outbound because conntrack
assert.False(t, fw.Drop([]byte{}, p, false, &c, cp))
// test caSha assertions true
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum"))
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
// test caSha assertions false
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope"))
assert.True(t, fw.Drop([]byte{}, p, true, &c, cp))
// test caName true
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", ""))
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
// test caName false
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", ""))
assert.True(t, fw.Drop([]byte{}, p, true, &c, cp))
}
func BenchmarkFirewallTable_match(b *testing.B) {
ft := FirewallTable{
TCP: firewallPort{},
}
_, n, _ := net.ParseCIDR("172.1.1.1/32")
ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp)
}
})
b.Run("fail on port", func(b *testing.B) {
c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp)
}
})
b.Run("fail all group, name, and cidr", func(b *testing.B) {
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "nope",
Ips: []*net.IPNet{ip},
},
}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
}
})
b.Run("pass on group", func(b *testing.B) {
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"good-group": {}},
Name: "nope",
},
}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
}
})
b.Run("pass on name", func(b *testing.B) {
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
},
}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
}
})
b.Run("pass on ip", func(b *testing.B) {
ip := ip2int(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
},
}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
}
})
ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
b.Run("pass on ip with any port", func(b *testing.B) {
ip := ip2int(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
},
}
for n := 0; n < b.N; n++ {
ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
}
})
}
func TestFirewall_Drop2(t *testing.T) {
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
101,
10,
90,
fwProtoUDP,
false,
}
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
},
}
c1 := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
},
}
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
cp := cert.NewCAPool()
// c1 lacks the proper groups
assert.True(t, fw.Drop([]byte{}, p, true, &c1, cp))
// c has the proper groups
assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
}
func BenchmarkLookup(b *testing.B) {
ml := func(m map[string]struct{}, a [][]string) {
for n := 0; n < b.N; n++ {
for _, sg := range a {
found := false
for _, g := range sg {
if _, ok := m[g]; !ok {
found = false
break
}
found = true
}
if found {
return
}
}
}
}
b.Run("array to map best", func(b *testing.B) {
m := map[string]struct{}{
"1ne": {},
"2wo": {},
"3hr": {},
"4ou": {},
"5iv": {},
"6ix": {},
}
a := [][]string{
{"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
{"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
{"one", "two", "3hr", "4ou", "5iv", "6ix"},
{"one", "two", "thr", "4ou", "5iv", "6ix"},
{"one", "two", "thr", "fou", "5iv", "6ix"},
{"one", "two", "thr", "fou", "fiv", "6ix"},
{"one", "two", "thr", "fou", "fiv", "six"},
}
for n := 0; n < b.N; n++ {
ml(m, a)
}
})
b.Run("array to map worst", func(b *testing.B) {
m := map[string]struct{}{
"one": {},
"two": {},
"thr": {},
"fou": {},
"fiv": {},
"six": {},
}
a := [][]string{
{"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
{"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
{"one", "two", "3hr", "4ou", "5iv", "6ix"},
{"one", "two", "thr", "4ou", "5iv", "6ix"},
{"one", "two", "thr", "fou", "5iv", "6ix"},
{"one", "two", "thr", "fou", "fiv", "6ix"},
{"one", "two", "thr", "fou", "fiv", "six"},
}
for n := 0; n < b.N; n++ {
ml(m, a)
}
})
//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
}
func Test_parsePort(t *testing.T) {
_, _, err := parsePort("")
assert.EqualError(t, err, "was not a number; ``")
_, _, err = parsePort(" ")
assert.EqualError(t, err, "was not a number; ` `")
_, _, err = parsePort("-")
assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
_, _, err = parsePort(" - ")
assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
_, _, err = parsePort("a-b")
assert.EqualError(t, err, "beginning range was not a number; `a`")
_, _, err = parsePort("1-b")
assert.EqualError(t, err, "ending range was not a number; `b`")
s, e, err := parsePort(" 1 - 2 ")
assert.Equal(t, int32(1), s)
assert.Equal(t, int32(2), e)
assert.Nil(t, err)
s, e, err = parsePort("0-1")
assert.Equal(t, int32(0), s)
assert.Equal(t, int32(0), e)
assert.Nil(t, err)
s, e, err = parsePort("9919")
assert.Equal(t, int32(9919), s)
assert.Equal(t, int32(9919), e)
assert.Nil(t, err)
s, e, err = parsePort("any")
assert.Equal(t, int32(0), s)
assert.Equal(t, int32(0), e)
assert.Nil(t, err)
}
func TestNewFirewallFromConfig(t *testing.T) {
// Test a bad rule definition
c := &cert.NebulaCertificate{}
conf := NewConfig()
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err := NewFirewallFromConfig(c, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = NewConfig()
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = NewConfig()
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = NewConfig()
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = NewConfig()
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = NewConfig()
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
// Test both group and groups
conf = NewConfig()
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(c, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
}
func TestAddFirewallRulesFromConfig(t *testing.T) {
// Test adding tcp rule
conf := NewConfig()
mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding udp rule
conf = NewConfig()
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding icmp rule
conf = NewConfig()
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding any rule
conf = NewConfig()
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding rule with ca_sha
conf = NewConfig()
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = NewConfig()
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
// Test single group
conf = NewConfig()
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test single groups
conf = NewConfig()
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test multiple AND groups
conf = NewConfig()
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
// Test Add error
conf = NewConfig()
mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`")
}
func TestTCPRTTTracking(t *testing.T) {
b := make([]byte, 200)
// Max ip IHL (60 bytes) and tcp IHL (60 bytes)
b[0] = 15
b[60+12] = 15 << 4
f := Firewall{
metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),
}
// Set SEQ to 1
binary.BigEndian.PutUint32(b[60+4:60+8], 1)
c := &conn{}
setTCPRTTTracking(c, b)
assert.Equal(t, uint32(1), c.Seq)
// Bad ack - no ack flag
binary.BigEndian.PutUint32(b[60+8:60+12], 80)
assert.False(t, f.checkTCPRTT(c, b))
// Bad ack, number is too low
binary.BigEndian.PutUint32(b[60+8:60+12], 0)
b[60+13] = uint8(0x10)
assert.False(t, f.checkTCPRTT(c, b))
// Good ack
binary.BigEndian.PutUint32(b[60+8:60+12], 80)
assert.True(t, f.checkTCPRTT(c, b))
assert.Equal(t, uint32(0), c.Seq)
// Set SEQ to 1
binary.BigEndian.PutUint32(b[60+4:60+8], 1)
c = &conn{}
setTCPRTTTracking(c, b)
assert.Equal(t, uint32(1), c.Seq)
// Good acks
binary.BigEndian.PutUint32(b[60+8:60+12], 81)
assert.True(t, f.checkTCPRTT(c, b))
assert.Equal(t, uint32(0), c.Seq)
// Set SEQ to max uint32 - 20
binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)
c = &conn{}
setTCPRTTTracking(c, b)
assert.Equal(t, ^uint32(0)-20, c.Seq)
// Good acks
binary.BigEndian.PutUint32(b[60+8:60+12], 81)
assert.True(t, f.checkTCPRTT(c, b))
assert.Equal(t, uint32(0), c.Seq)
// Set SEQ to max uint32 / 2
binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)
c = &conn{}
setTCPRTTTracking(c, b)
assert.Equal(t, ^uint32(0)/2, c.Seq)
// Below
binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)
assert.False(t, f.checkTCPRTT(c, b))
assert.Equal(t, ^uint32(0)/2, c.Seq)
// Halfway below
binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))
assert.False(t, f.checkTCPRTT(c, b))
assert.Equal(t, ^uint32(0)/2, c.Seq)
// Halfway above is ok
binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))
assert.True(t, f.checkTCPRTT(c, b))
assert.Equal(t, uint32(0), c.Seq)
// Set SEQ to max uint32
binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))
c = &conn{}
setTCPRTTTracking(c, b)
assert.Equal(t, ^uint32(0), c.Seq)
// Halfway + 1 above
binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)
assert.False(t, f.checkTCPRTT(c, b))
assert.Equal(t, ^uint32(0), c.Seq)
// Halfway above
binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)
assert.True(t, f.checkTCPRTT(c, b))
assert.Equal(t, uint32(0), c.Seq)
}
type addRuleCall struct {
incoming bool
proto uint8
startPort int32
endPort int32
groups []string
host string
ip *net.IPNet
caName string
caSha string
}
type mockFirewall struct {
lastCall addRuleCall
nextCallReturn error
}
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
mf.lastCall = addRuleCall{
incoming: incoming,
proto: proto,
startPort: startPort,
endPort: endPort,
groups: groups,
host: host,
ip: ip,
caName: caName,
caSha: caSha,
}
err := mf.nextCallReturn
mf.nextCallReturn = nil
return err
}

32
go.mod Normal file
View File

@ -0,0 +1,32 @@
module github.com/slackhq/nebula
go 1.12
require (
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6
github.com/golang/protobuf v1.3.1
github.com/imdario/mergo v0.3.7
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/miekg/dns v1.1.12
github.com/nbrownus/go-metrics-prometheus v0.0.0-20180622211546-6e6d5173d99c
github.com/prometheus/client_golang v0.9.3
github.com/prometheus/common v0.4.1 // indirect
github.com/prometheus/procfs v0.0.0-20190523193104-a7aeb8df3389 // indirect
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a
github.com/sirupsen/logrus v1.4.2
github.com/songgao/water v0.0.0-20190402020555-6ad6edefb15c
github.com/stretchr/testify v1.3.0
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc // indirect
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f
golang.org/x/net v0.0.0-20190522155817-f3200d17e092
golang.org/x/sync v0.0.0-20190423024810-112230192c58 // indirect
golang.org/x/sys v0.0.0-20190524152521-dbbf3f1254d4
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/yaml.v2 v2.2.2
)

112
go.sum Normal file
View File

@ -0,0 +1,112 @@
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239 h1:kFOfPq6dUM1hTo4JG6LR5AXSUEsOjtdm0kw0FtQtMJA=
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI=
github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps=
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 h1:u/UEqS66A5ckRmS4yNpjmVH56sVtS/RfclBAYocb4as=
github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6/go.mod h1:1i71OnUq3iUe1ma7Lr6yG6/rjvM3emb6yoL7xLFzcVQ=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/imdario/mergo v0.3.7 h1:Y+UAYTZ7gDEuOfhxKWy+dvb5dRQ6rJjFSdX2HZY1/gI=
github.com/imdario/mergo v0.3.7/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.12 h1:WMhc1ik4LNkTg8U9l3hI1LvxKmIL+f1+WV/SZtCbDDA=
github.com/miekg/dns v1.1.12/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/nbrownus/go-metrics-prometheus v0.0.0-20180622211546-6e6d5173d99c h1:G/mfx/MWYuaaGlHkZQBBXFAJiYnRt/GaOVxnRHjlxg4=
github.com/nbrownus/go-metrics-prometheus v0.0.0-20180622211546-6e6d5173d99c/go.mod h1:1yMri853KAI2pPAUnESjaqZj9JeImOUM+6A4GuuPmTs=
github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.3 h1:9iH4JKXLzFbOAdtqv/a+j8aewx2Y8lAjAydhbaScPF8=
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/common v0.4.1 h1:K0MGApIoQvMw27RTdJkPbr3JZ7DNbtxQNyi5STVM6Kw=
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.0.0-20190523193104-a7aeb8df3389 h1:F/k2nob1S9M6v5Xkq7KjSTQirOYaYQord0jR4TwyVmY=
github.com/prometheus/procfs v0.0.0-20190523193104-a7aeb8df3389/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a h1:9ZKAASQSHhDYGoxY8uLVpewe1GDZ2vu2Tr/vTdVAkFQ=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/songgao/water v0.0.0-20190402020555-6ad6edefb15c h1:VZsL/fAl8XnHj5Zn+cRvLcFbMHmCj7tdPrkKZSRziJ0=
github.com/songgao/water v0.0.0-20190402020555-6ad6edefb15c/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4=
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f h1:R423Cnkcp5JABoeemiGEPlt9tHXFfw5kvc0yqlxRPWo=
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190522155817-f3200d17e092 h1:4QSRKanuywn15aTZvI/mIDEgPQpswuFndXpOj3rKEco=
golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190524152521-dbbf3f1254d4 h1:VSJ45BzqrVgR4clSx415y1rHH7QAGhGt71J0ZmhLYrc=
golang.org/x/sys v0.0.0-20190524152521-dbbf3f1254d4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

82
handshake.go Normal file
View File

@ -0,0 +1,82 @@
package nebula
import (
"crypto/hmac"
"crypto/sha256"
"errors"
"github.com/golang/protobuf/proto"
)
const (
handshakeIXPSK0 = 0
handshakeXXPSK0 = 1
)
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
//TODO: For stage 1 we won't have hostinfo yet but stage 2 and above would require it, this check may be helpful in those cases
//if err != nil {
// l.WithError(err).WithField("udpAddr", addr).Error("Error while finding host info for handshake message")
// return
//}
tearDown := false
switch h.Subtype {
case handshakeIXPSK0:
switch h.MessageCounter {
case 1:
tearDown = ixHandshakeStage1(f, addr, newHostinfo, packet, h)
case 2:
tearDown = ixHandshakeStage2(f, addr, newHostinfo, packet, h)
}
}
if tearDown && newHostinfo != nil {
f.handshakeManager.DeleteIndex(newHostinfo.localIndexId)
f.handshakeManager.DeleteVpnIP(newHostinfo.hostId)
}
}
func HandshakeBytesWithMAC(details *NebulaHandshakeDetails, key []byte) ([]byte, error) {
mac := hmac.New(sha256.New, key)
b, err := proto.Marshal(details)
if err != nil {
return nil, errors.New("Unable to marshal nebula handshake")
}
mac.Write(b)
sum := mac.Sum(nil)
hs := &NebulaHandshake{
Details: details,
Hmac: sum,
}
hsBytes, err := proto.Marshal(hs)
if err != nil {
l.Debugln("failed to generate NebulaHandshake protobuf", err)
}
return hsBytes, nil
}
func (hs *NebulaHandshake) CheckHandshakeMAC(keys [][]byte) bool {
b, err := proto.Marshal(hs.Details)
if err != nil {
return false
}
for _, k := range keys {
mac := hmac.New(sha256.New, k)
mac.Write(b)
expectedMAC := mac.Sum(nil)
if hmac.Equal(hs.Hmac, expectedMAC) {
return true
}
}
//l.Debugln(hs.Hmac, expectedMAC)
return false
}

356
handshake_ix.go Normal file
View File

@ -0,0 +1,356 @@
package nebula
import (
"sync/atomic"
"time"
"bytes"
"github.com/flynn/noise"
"github.com/golang/protobuf/proto"
)
// NOISE IX Handshakes
// This function constructs a handshake packet, but does not actually send it
// Sending is done by the handshake manager
func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
// This queries the lighthouse if we don't know a remote for the host
if hostinfo.remote == nil {
ips, err := f.lightHouse.Query(vpnIp, f)
if err != nil {
//l.Debugln(err)
}
for _, ip := range ips {
hostinfo.AddRemote(ip)
}
}
myIndex, err := generateIndex()
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return
}
ci := hostinfo.ConnectionState
f.handshakeManager.AddIndexHostInfo(myIndex, hostinfo)
hsProto := &NebulaHandshakeDetails{
InitiatorIndex: myIndex,
Time: uint64(time.Now().Unix()),
Cert: ci.certState.rawCertificateNoKey,
}
hs := &NebulaHandshake{
Details: hsProto,
Hmac: nil,
}
hsBytes, err := proto.Marshal(hs)
//hsBytes, err := HandshakeBytesWithMAC(hsProto, f.handshakeMACKey)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return
}
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1)
atomic.AddUint64(ci.messageCounter, 1)
msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
}
hostinfo.HandshakePacket[0] = msg
hostinfo.HandshakeReady = true
hostinfo.handshakeStart = time.Now()
/*
l.Debugln("ZZZZZZZZZZREMOTE: ", hostinfo.remote)
*/
}
func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
var ip uint32
if h.RemoteIndex == 0 {
ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
if err != nil {
l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
return true
}
hs := &NebulaHandshake{}
err = proto.Unmarshal(msg, hs)
/*
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
*/
if err != nil || hs.Details == nil {
l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
return true
}
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex)
if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
if msg, ok := hostinfo.HandshakePacket[2]; ok {
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
return false
}
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true).
WithField("packets", hostinfo.HandshakePacket).
Error("Seen this handshake packet already but don't have a cached packet to return")
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
if err != nil {
l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Invalid certificate from host")
return true
}
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
myIndex, err := generateIndex()
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
return true
}
hostinfo, err = f.handshakeManager.AddIndex(myIndex, ci)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Error adding index to connection manager")
return true
}
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message received")
hostinfo.remoteIndexId = hs.Details.InitiatorIndex
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = ci.certState.rawCertificateNoKey
hs.Hmac = nil
hsBytes, err := proto.Marshal(hs)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return true
}
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return true
}
if f.hostMap.CheckHandshakeCompleteIP(vpnIP) && vpnIP < ip2int(f.certState.certificate.Details.Ips[0].IP) {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Prevented a handshake race")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return true
}
hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:]))
copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
// Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection.
if dKey != nil && eKey != nil {
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
copy(hostinfo.HandshakePacket[2], msg)
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
}
ip = ip2int(remoteCert.Details.Ips[0].IP)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs")
//hostinfo.ClearRemotes()
hostinfo.AddRemote(*addr)
f.lightHouse.AddRemoteAndReset(ip, addr)
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
ho, err := f.hostMap.QueryVpnIP(vpnIP)
if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP).
WithField("action", "removing stale index").
WithField("index", ho.localIndexId).
Debug("Handshake processing")
f.hostMap.DeleteIndex(ho.localIndexId)
}
f.hostMap.AddIndexHostInfo(hostinfo.localIndexId, hostinfo)
f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo)
hostinfo.handshakeComplete()
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true
}
}
f.hostMap.AddRemote(ip, addr)
/*
l.Debugln("111 ZZZZZZZZZZADDR: ", addr)
l.Debugln("111 ZZZZZZZZZZREMOTE: ", hostinfo.remote)
l.Debugln("111 ZZZZZZZZZZREMOTEs: ", hostinfo.Remotes[0].addr)
*/
return false
}
func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
if hostinfo == nil {
return true
}
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Already seen this handshake packet")
return false
}
ci := hostinfo.ConnectionState
// Mark packet 2 as seen so it doesn't show up as missed
ci.window.Update(2)
hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:]))
copy(hostinfo.HandshakePacket[2], packet[HeaderLen:])
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage")
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
// near future
return false
}
hs := &NebulaHandshake{}
err = proto.Unmarshal(msg, hs)
if err != nil || hs.Details == nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
return true
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Invalid certificate from host")
return true
}
vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration).
Info("Handshake message received")
//ci.remoteIndex = hs.ResponderIndex
hostinfo.remoteIndexId = hs.Details.ResponderIndex
hs.Details.Cert = ci.certState.rawCertificateNoKey
/*
hsBytes, err := proto.Marshal(hs)
if err != nil {
l.Debugln("Failed to marshal handshake: ", err)
return
}
*/
// Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection.
if dKey != nil && eKey != nil {
ip := ip2int(remoteCert.Details.Ips[0].IP)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
//l.Debugln("got symmetric pairs")
//hostinfo.ClearRemotes()
f.hostMap.AddRemote(ip, addr)
f.lightHouse.AddRemoteAndReset(ip, addr)
if f.serveDns {
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
}
ho, err := f.hostMap.QueryVpnIP(vpnIP)
if err == nil && ho.localIndexId != 0 {
l.WithField("vpnIp", vpnIP).
WithField("action", "removing stale index").
WithField("index", ho.localIndexId).
Debug("Handshake processing")
f.hostMap.DeleteIndex(ho.localIndexId)
}
f.hostMap.AddVpnIPHostInfo(vpnIP, hostinfo)
f.hostMap.AddIndexHostInfo(hostinfo.localIndexId, hostinfo)
hostinfo.handshakeComplete()
f.metricHandshakes.Update(duration)
} else {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true
}
return false
/*
l.Debugln("222 ZZZZZZZZZZREMOTE: ", hostinfo.remote)
*/
}

200
handshake_manager.go Normal file
View File

@ -0,0 +1,200 @@
package nebula
import (
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"time"
"github.com/sirupsen/logrus"
)
const (
// Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries
// With 100ms interval and 20 retries is 23.5 seconds
HandshakeTryInterval = time.Millisecond * 100
HandshakeRetries = 20
// HandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
HandshakeWaitRotation = 5
)
type HandshakeManager struct {
pendingHostMap *HostMap
mainHostMap *HostMap
lightHouse *LightHouse
outside *udpConn
OutboundHandshakeTimer *SystemTimerWheel
InboundHandshakeTimer *SystemTimerWheel
}
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn) *HandshakeManager {
return &HandshakeManager{
pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
OutboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries),
InboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries),
}
}
func (c *HandshakeManager) Run(f EncWriter) {
clockSource := time.Tick(HandshakeTryInterval)
for now := range clockSource {
c.NextOutboundHandshakeTimerTick(now, f)
c.NextInboundHandshakeTimerTick(now)
}
}
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
c.OutboundHandshakeTimer.advance(now)
for {
ep := c.OutboundHandshakeTimer.Purge()
if ep == nil {
break
}
vpnIP := ep.(uint32)
index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP)
if err != nil {
continue
}
hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
if err != nil {
continue
}
// If we haven't finished the handshake and we haven't hit max retries, query
// lighthouse and then send the handshake packet again.
if hostinfo.HandshakeCounter < HandshakeRetries && !hostinfo.HandshakeComplete {
if hostinfo.remote == nil {
// We continue to query the lighthouse because hosts may
// come online during handshake retries. If the query
// succeeds (no error), add the lighthouse info to hostinfo
ips, err := c.lightHouse.Query(vpnIP, f)
if err == nil {
for _, ip := range ips {
hostinfo.AddRemote(ip)
}
hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
}
}
hostinfo.HandshakeCounter++
// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
// all the others until we can stand up a connection.
if hostinfo.HandshakeCounter > HandshakeWaitRotation {
hostinfo.rotateRemote()
}
// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
if hostinfo.HandshakeReady && hostinfo.remote != nil {
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake message")
} else {
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
// keep the real packet struct around for logging purposes
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent")
}
}
// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
} else {
c.pendingHostMap.DeleteVpnIP(vpnIP)
c.pendingHostMap.DeleteIndex(index)
}
}
}
func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
c.InboundHandshakeTimer.advance(now)
for {
ep := c.InboundHandshakeTimer.Purge()
if ep == nil {
break
}
index := ep.(uint32)
vpnIP, err := c.pendingHostMap.GetVpnIPByIndex(index)
if err != nil {
continue
}
c.pendingHostMap.DeleteIndex(index)
c.pendingHostMap.DeleteVpnIP(vpnIP)
}
}
func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
// We lock here and use an array to insert items to prevent locking the
// main receive thread for very long by waiting to add items to the pending map
c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval)
return hostinfo
}
func (c *HandshakeManager) DeleteVpnIP(vpnIP uint32) {
//l.Debugln("Deleting pending vpn ip :", IntIp(vpnIP))
c.pendingHostMap.DeleteVpnIP(vpnIP)
}
func (c *HandshakeManager) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) {
hostinfo, err := c.pendingHostMap.AddIndex(index, ci)
if err != nil {
return nil, fmt.Errorf("Issue adding index: %d", index)
}
//c.mainHostMap.AddIndexHostInfo(index, hostinfo)
c.InboundHandshakeTimer.Add(index, time.Second*10)
return hostinfo, nil
}
func (c *HandshakeManager) AddIndexHostInfo(index uint32, h *HostInfo) {
c.pendingHostMap.AddIndexHostInfo(index, h)
}
func (c *HandshakeManager) DeleteIndex(index uint32) {
//l.Debugln("Deleting pending index :", index)
c.pendingHostMap.DeleteIndex(index)
}
func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
return c.pendingHostMap.QueryIndex(index)
}
func (c *HandshakeManager) EmitStats() {
c.pendingHostMap.EmitStats("pending")
c.mainHostMap.EmitStats("main")
}
// Utility functions below
func generateIndex() (uint32, error) {
b := make([]byte, 4)
_, err := rand.Read(b)
if err != nil {
l.Errorln(err)
return 0, err
}
index := binary.BigEndian.Uint32(b)
if l.Level >= logrus.DebugLevel {
l.WithField("index", index).
Debug("Generated index")
}
return index, nil
}

191
handshake_manager_test.go Normal file
View File

@ -0,0 +1,191 @@
package nebula
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var indexes []uint32 = []uint32{1000, 2000, 3000, 4000}
//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
var ips []uint32 = []uint32{9000}
func Test_NewHandshakeManagerIndex(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)
// Add four indexes
for _, v := range indexes {
blah.AddIndex(v, &ConnectionState{})
}
// Confirm they are in the pending index list
for _, v := range indexes {
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
}
// Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Indexes, 0)
// Jump ahead 8 seconds
for i := 1; i <= HandshakeRetries; i++ {
next_tick := now.Add(HandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick)
}
// Confirm they are still in the pending index list
for _, v := range indexes {
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
}
// Jump ahead 4 more seconds
next_tick := now.Add(12 * time.Second)
blah.NextInboundHandshakeTimerTick(next_tick)
// Confirm they have been removed
for _, v := range indexes {
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(v))
}
}
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
// Add four "IPs" - which are just uint32s
for _, v := range ips {
blah.AddVpnIP(v)
}
// Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Hosts, 0)
// Confirm they are in the pending index list
for _, v := range ips {
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
}
// Jump ahead `HandshakeRetries` ticks
cumulative := time.Duration(0)
for i := 0; i <= HandshakeRetries+1; i++ {
cumulative += time.Duration(i)*HandshakeTryInterval + 1
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
}
// Confirm they are still in the pending index list
for _, v := range ips {
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
}
// Jump ahead 1 more second
cumulative += time.Duration(HandshakeRetries+1) * HandshakeTryInterval
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
// Confirm they have been removed
for _, v := range ips {
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(v))
}
}
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
hostinfo := blah.AddVpnIP(101010)
// Pretned we have an index too
blah.AddIndexHostInfo(12341234, hostinfo)
assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234))
// Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
// but not main hostmap
cumulative := time.Duration(0)
for i := 1; i <= HandshakeRetries+2; i++ {
cumulative += HandshakeTryInterval * time.Duration(i)
next_tick := now.Add(cumulative)
blah.NextOutboundHandshakeTimerTick(next_tick, mw)
}
/*
for i := 0; i <= HandshakeRetries+1; i++ {
next_tick := now.Add(cumulative)
//l.Infoln(next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick)
}
*/
/*
for i := 0; i <= HandshakeRetries+1; i++ {
next_tick := now.Add(time.Duration(i) * time.Second)
blah.NextOutboundHandshakeTimerTick(next_tick)
}
*/
/*
cumulative += HandshakeTryInterval*time.Duration(HandshakeRetries) + 3
next_tick := now.Add(cumulative)
l.Infoln(cumulative, next_tick)
blah.NextOutboundHandshakeTimerTick(next_tick)
*/
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
}
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{})
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)
hostinfo, _ := blah.AddIndex(12341234, &ConnectionState{})
// Pretned we have an index too
blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
for i := 1; i <= HandshakeRetries+2; i++ {
next_tick := now.Add(HandshakeTryInterval * time.Duration(i))
blah.NextInboundHandshakeTimerTick(next_tick)
}
next_tick := now.Add(HandshakeTryInterval*HandshakeRetries + 3)
blah.NextInboundHandshakeTimerTick(next_tick)
assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
}
type mockEncWriter struct {
}
func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
return
}
func (mw *mockEncWriter) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
return
}

188
header.go Normal file
View File

@ -0,0 +1,188 @@
package nebula
import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
)
//Version 1 header:
// 0 31
// |-----------------------------------------------------------------------|
// | Version (uint4) | Type (uint4) | Subtype (uint8) | Reserved (uint16) | 32
// |-----------------------------------------------------------------------|
// | Remote index (uint32) | 64
// |-----------------------------------------------------------------------|
// | Message counter | 96
// | (uint64) | 128
// |-----------------------------------------------------------------------|
// | payload... |
const (
Version uint8 = 1
HeaderLen = 16
)
type NebulaMessageType uint8
type NebulaMessageSubType uint8
const (
handshake NebulaMessageType = 0
message NebulaMessageType = 1
recvError NebulaMessageType = 2
lightHouse NebulaMessageType = 3
test NebulaMessageType = 4
closeTunnel NebulaMessageType = 5
//TODO These are deprecated as of 06/12/2018 - NB
testRemote NebulaMessageType = 6
testRemoteReply NebulaMessageType = 7
)
var typeMap = map[NebulaMessageType]string{
handshake: "handshake",
message: "message",
recvError: "recvError",
lightHouse: "lightHouse",
test: "test",
closeTunnel: "closeTunnel",
//TODO These are deprecated as of 06/12/2018 - NB
testRemote: "testRemote",
testRemoteReply: "testRemoteReply",
}
const (
testRequest NebulaMessageSubType = 0
testReply NebulaMessageSubType = 1
)
var eHeaderTooShort = errors.New("header is too short")
var subTypeTestMap = map[NebulaMessageSubType]string{
testRequest: "testRequest",
testReply: "testReply",
}
var subTypeNoneMap = map[NebulaMessageSubType]string{0: "none"}
var subTypeMap = map[NebulaMessageType]*map[NebulaMessageSubType]string{
message: &subTypeNoneMap,
recvError: &subTypeNoneMap,
lightHouse: &subTypeNoneMap,
test: &subTypeTestMap,
closeTunnel: &subTypeNoneMap,
handshake: {
handshakeIXPSK0: "ix_psk0",
},
//TODO: these are deprecated
testRemote: &subTypeNoneMap,
testRemoteReply: &subTypeNoneMap,
}
type Header struct {
Version uint8
Type NebulaMessageType
Subtype NebulaMessageSubType
Reserved uint16
RemoteIndex uint32
MessageCounter uint64
}
// HeaderEncode uses the provided byte array to encode the provided header values into.
// Byte array must be capped higher than HeaderLen or this will panic
func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []byte {
b = b[:HeaderLen]
b[0] = byte(v<<4 | (t & 0x0f))
b[1] = byte(st)
binary.BigEndian.PutUint16(b[2:4], 0)
binary.BigEndian.PutUint32(b[4:8], ri)
binary.BigEndian.PutUint64(b[8:16], c)
return b
}
// String creates a readable string representation of a header
func (h *Header) String() string {
if h == nil {
return "<nil>"
}
return fmt.Sprintf("ver=%d type=%s subtype=%s reserved=%#x remoteindex=%v messagecounter=%v",
h.Version, h.TypeName(), h.SubTypeName(), h.Reserved, h.RemoteIndex, h.MessageCounter)
}
// MarshalJSON creates a json string representation of a header
func (h *Header) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"version": h.Version,
"type": h.TypeName(),
"subType": h.SubTypeName(),
"reserved": h.Reserved,
"remoteIndex": h.RemoteIndex,
"messageCounter": h.MessageCounter,
})
}
// Encode turns header into bytes
func (h *Header) Encode(b []byte) ([]byte, error) {
if h == nil {
return nil, errors.New("nil header")
}
return HeaderEncode(b, h.Version, uint8(h.Type), uint8(h.Subtype), h.RemoteIndex, h.MessageCounter), nil
}
// Parse is a helper function to parses given bytes into new Header struct
func (h *Header) Parse(b []byte) error {
if len(b) < HeaderLen {
return eHeaderTooShort
}
// get upper 4 bytes
h.Version = uint8((b[0] >> 4) & 0x0f)
// get lower 4 bytes
h.Type = NebulaMessageType(b[0] & 0x0f)
h.Subtype = NebulaMessageSubType(b[1])
h.Reserved = binary.BigEndian.Uint16(b[2:4])
h.RemoteIndex = binary.BigEndian.Uint32(b[4:8])
h.MessageCounter = binary.BigEndian.Uint64(b[8:16])
return nil
}
// TypeName will transform the headers message type into a human string
func (h *Header) TypeName() string {
return TypeName(h.Type)
}
// TypeName will transform a nebula message type into a human string
func TypeName(t NebulaMessageType) string {
if n, ok := typeMap[t]; ok {
return n
}
return "unknown"
}
// SubTypeName will transform the headers message sub type into a human string
func (h *Header) SubTypeName() string {
return SubTypeName(h.Type, h.Subtype)
}
// SubTypeName will transform a nebula message sub type into a human string
func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string {
if n, ok := subTypeMap[t]; ok {
if x, ok := (*n)[s]; ok {
return x
}
}
return "unknown"
}
// NewHeader turns bytes into a header
func NewHeader(b []byte) (*Header, error) {
h := new(Header)
if err := h.Parse(b); err != nil {
return nil, err
}
return h, nil
}

118
header_test.go Normal file
View File

@ -0,0 +1,118 @@
package nebula
import (
"github.com/stretchr/testify/assert"
"reflect"
"testing"
)
type headerTest struct {
expectedBytes []byte
*Header
}
// 0001 0010 00010010
var headerBigEndianTests = []headerTest{{
expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
// 1010 0000
Header: &Header{
// 1111 1+2+4+8 = 15
Version: 5,
Type: 4,
Subtype: 0,
Reserved: 0,
RemoteIndex: 10,
MessageCounter: 9,
},
},
}
func TestEncode(t *testing.T) {
for _, tt := range headerBigEndianTests {
b, err := tt.Encode(make([]byte, HeaderLen))
if err != nil {
t.Fatal(err)
}
assert.Equal(t, tt.expectedBytes, b)
}
}
func TestParse(t *testing.T) {
for _, tt := range headerBigEndianTests {
b := tt.expectedBytes
parsedHeader := &Header{}
parsedHeader.Parse(b)
if !reflect.DeepEqual(tt.Header, parsedHeader) {
t.Fatalf("got %#v; want %#v", parsedHeader, tt.Header)
}
}
}
func TestTypeName(t *testing.T) {
assert.Equal(t, "test", TypeName(test))
assert.Equal(t, "test", (&Header{Type: test}).TypeName())
assert.Equal(t, "unknown", TypeName(99))
assert.Equal(t, "unknown", (&Header{Type: 99}).TypeName())
}
func TestSubTypeName(t *testing.T) {
assert.Equal(t, "testRequest", SubTypeName(test, testRequest))
assert.Equal(t, "testRequest", (&Header{Type: test, Subtype: testRequest}).SubTypeName())
assert.Equal(t, "unknown", SubTypeName(99, testRequest))
assert.Equal(t, "unknown", (&Header{Type: 99, Subtype: testRequest}).SubTypeName())
assert.Equal(t, "unknown", SubTypeName(test, 99))
assert.Equal(t, "unknown", (&Header{Type: test, Subtype: 99}).SubTypeName())
assert.Equal(t, "none", SubTypeName(message, 0))
assert.Equal(t, "none", (&Header{Type: message, Subtype: 0}).SubTypeName())
}
func TestTypeMap(t *testing.T) {
// Force people to document this stuff
assert.Equal(t, map[NebulaMessageType]string{
handshake: "handshake",
message: "message",
recvError: "recvError",
lightHouse: "lightHouse",
test: "test",
closeTunnel: "closeTunnel",
testRemote: "testRemote",
testRemoteReply: "testRemoteReply",
}, typeMap)
assert.Equal(t, map[NebulaMessageType]*map[NebulaMessageSubType]string{
message: &subTypeNoneMap,
recvError: &subTypeNoneMap,
lightHouse: &subTypeNoneMap,
test: &subTypeTestMap,
closeTunnel: &subTypeNoneMap,
handshake: {
handshakeIXPSK0: "ix_psk0",
},
testRemote: &subTypeNoneMap,
testRemoteReply: &subTypeNoneMap,
}, subTypeMap)
}
func TestHeader_String(t *testing.T) {
assert.Equal(
t,
"ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97",
(&Header{100, test, testRequest, 99, 98, 97}).String(),
)
}
func TestHeader_MarshalJSON(t *testing.T) {
b, err := (&Header{100, test, testRequest, 99, 98, 97}).MarshalJSON()
assert.Nil(t, err)
assert.Equal(
t,
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
string(b),
)
}

743
hostmap.go Normal file
View File

@ -0,0 +1,743 @@
package nebula
import (
"encoding/json"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
//const ProbeLen = 100
const PromoteEvery = 1000
const MaxRemotes = 10
// How long we should prevent roaming back to the previous IP.
// This helps prevent flapping due to packets already in flight
const RoamingSupressSeconds = 2
type HostMap struct {
sync.RWMutex //Because we concurrently read and write to our maps
name string
Indexes map[uint32]*HostInfo
Hosts map[uint32]*HostInfo
preferredRanges []*net.IPNet
vpnCIDR *net.IPNet
defaultRoute uint32
}
type HostInfo struct {
remote *udpAddr
Remotes []*HostInfoDest
promoteCounter uint32
ConnectionState *ConnectionState
handshakeStart time.Time
HandshakeReady bool
HandshakeCounter int
HandshakeComplete bool
HandshakePacket map[uint8][]byte
packetStore []*cachedPacket
remoteIndexId uint32
localIndexId uint32
hostId uint32
recvError int
lastRoam time.Time
lastRoamRemote *udpAddr
}
type cachedPacket struct {
messageType NebulaMessageType
messageSubType NebulaMessageSubType
callback packetCallback
packet []byte
}
type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte)
type HostInfoDest struct {
active bool
addr *udpAddr
//probes [ProbeLen]bool
probeCounter int
}
type Probe struct {
Addr *net.UDPAddr
Counter int
}
func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[uint32]*HostInfo{}
i := map[uint32]*HostInfo{}
m := HostMap{
name: name,
Indexes: i,
Hosts: h,
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR,
defaultRoute: 0,
}
return &m
}
// UpdateStats takes a name and reports host and index counts to the stats collection system
func (hm *HostMap) EmitStats(name string) {
hm.RLock()
hostLen := len(hm.Hosts)
indexLen := len(hm.Indexes)
hm.RUnlock()
metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen))
}
func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
hm.RLock()
if i, ok := hm.Hosts[vpnIP]; ok {
index := i.localIndexId
hm.RUnlock()
return index, nil
}
hm.RUnlock()
return 0, errors.New("vpn IP not found")
}
func (hm *HostMap) GetVpnIPByIndex(index uint32) (uint32, error) {
hm.RLock()
if i, ok := hm.Indexes[index]; ok {
vpnIP := i.hostId
hm.RUnlock()
return vpnIP, nil
}
hm.RUnlock()
return 0, errors.New("vpn IP not found")
}
func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) {
hm.Lock()
hm.Hosts[ip] = hostinfo
hm.Unlock()
}
func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
h := &HostInfo{}
hm.RLock()
if _, ok := hm.Hosts[vpnIP]; !ok {
hm.RUnlock()
h = &HostInfo{
Remotes: []*HostInfoDest{},
promoteCounter: 0,
hostId: vpnIP,
HandshakePacket: make(map[uint8][]byte, 0),
}
hm.Lock()
hm.Hosts[vpnIP] = h
hm.Unlock()
return h
} else {
h = hm.Hosts[vpnIP]
hm.RUnlock()
return h
}
}
func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
hm.Lock()
delete(hm.Hosts, vpnIP)
if len(hm.Hosts) == 0 {
hm.Hosts = map[uint32]*HostInfo{}
}
hm.Unlock()
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap vpnIp deleted")
}
}
func (hm *HostMap) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) {
hm.Lock()
if _, ok := hm.Indexes[index]; !ok {
h := &HostInfo{
ConnectionState: ci,
Remotes: []*HostInfoDest{},
localIndexId: index,
HandshakePacket: make(map[uint8][]byte, 0),
}
hm.Indexes[index] = h
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": false, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap index added")
hm.Unlock()
return h, nil
}
hm.Unlock()
return nil, fmt.Errorf("refusing to overwrite existing index: %d", index)
}
func (hm *HostMap) AddIndexHostInfo(index uint32, h *HostInfo) {
hm.Lock()
h.localIndexId = index
hm.Indexes[index] = h
hm.Unlock()
if l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap index added")
}
}
func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
hm.Lock()
h.hostId = vpnIP
hm.Hosts[vpnIP] = h
hm.Unlock()
if l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap vpnIp added")
}
}
func (hm *HostMap) DeleteIndex(index uint32) {
hm.Lock()
delete(hm.Indexes, index)
if len(hm.Indexes) == 0 {
hm.Indexes = map[uint32]*HostInfo{}
}
hm.Unlock()
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
Debug("Hostmap index deleted")
}
}
func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
//TODO: we probably just want ot return bool instead of error, or at least a static error
hm.RLock()
if h, ok := hm.Indexes[index]; ok {
hm.RUnlock()
return h, nil
} else {
hm.RUnlock()
return nil, errors.New("unable to find index")
}
}
// This function needs to range because we don't keep a map of remote indexes.
func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
hm.RLock()
for _, h := range hm.Indexes {
if h.ConnectionState != nil && h.remoteIndexId == index {
hm.RUnlock()
return h, nil
}
}
for _, h := range hm.Hosts {
if h.ConnectionState != nil && h.remoteIndexId == index {
hm.RUnlock()
return h, nil
}
}
hm.RUnlock()
return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name)
}
func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
hm.Lock()
i, v := hm.Hosts[vpnIp]
if v {
i.AddRemote(*remote)
} else {
i = &HostInfo{
Remotes: []*HostInfoDest{NewHostInfoDest(remote)},
promoteCounter: 0,
hostId: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
}
i.remote = i.Remotes[0].addr
hm.Hosts[vpnIp] = i
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap remote ip added")
}
i.ForcePromoteBest(hm.preferredRanges)
hm.Unlock()
return i
}
func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, nil)
}
// PromoteBestQueryVpnIP will attempt to lazily switch to the best remote every
// `PromoteEvery` calls to this function for a given host.
func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, ifce)
}
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
if hm.vpnCIDR.Contains(int2ip(vpnIp)) == false && hm.defaultRoute != 0 {
// FIXME: this shouldn't ship
d := hm.Hosts[hm.defaultRoute]
if d != nil {
return hm.Hosts[hm.defaultRoute], nil
}
}
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
if promoteIfce != nil {
h.TryPromoteBest(hm.preferredRanges, promoteIfce)
}
//fmt.Println(h.remote)
hm.RUnlock()
return h, nil
} else {
//return &net.UDPAddr{}, nil, errors.New("Unable to find host")
hm.RUnlock()
/*
if lightHouse != nil {
lightHouse.Query(vpnIp)
return nil, errors.New("Unable to find host")
}
*/
return nil, errors.New("unable to find host")
}
}
func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool {
hm.RLock()
if i, ok := hm.Hosts[vpnIP]; ok {
if i == nil {
hm.RUnlock()
return false
}
complete := i.HandshakeComplete
hm.RUnlock()
return complete
}
hm.RUnlock()
return false
}
func (hm *HostMap) CheckHandshakeCompleteIndex(index uint32) bool {
hm.RLock()
if i, ok := hm.Indexes[index]; ok {
if i == nil {
hm.RUnlock()
return false
}
complete := i.HandshakeComplete
hm.RUnlock()
return complete
}
hm.RUnlock()
return false
}
func (hm *HostMap) ClearRemotes(vpnIP uint32) {
hm.Lock()
i := hm.Hosts[vpnIP]
if i == nil {
hm.Unlock()
return
}
i.remote = nil
i.Remotes = nil
hm.Unlock()
}
func (hm *HostMap) SetDefaultRoute(ip uint32) {
hm.defaultRoute = ip
}
func (hm *HostMap) PunchList() []*udpAddr {
var list []*udpAddr
hm.RLock()
for _, v := range hm.Hosts {
for _, r := range v.Remotes {
list = append(list, r.addr)
}
// if h, ok := hm.Hosts[vpnIp]; ok {
// hm.Hosts[vpnIp].PromoteBest(hm.preferredRanges, false)
//fmt.Println(h.remote)
// }
}
hm.RUnlock()
return list
}
func (hm *HostMap) Punchy(conn *udpConn) {
for {
for _, addr := range hm.PunchList() {
conn.WriteTo([]byte{1}, addr)
}
time.Sleep(time.Second * 30)
}
}
func (i *HostInfo) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"remote": i.remote,
"remotes": i.Remotes,
"promote_counter": i.promoteCounter,
"connection_state": i.ConnectionState,
"handshake_start": i.handshakeStart,
"handshake_ready": i.HandshakeReady,
"handshake_counter": i.HandshakeCounter,
"handshake_complete": i.HandshakeComplete,
"handshake_packet": i.HandshakePacket,
"packet_store": i.packetStore,
"remote_index": i.remoteIndexId,
"local_index": i.localIndexId,
"host_id": int2ip(i.hostId),
"receive_errors": i.recvError,
"last_roam": i.lastRoam,
"last_roam_remote": i.lastRoamRemote,
})
}
func (i *HostInfo) BindConnectionState(cs *ConnectionState) {
i.ConnectionState = cs
}
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
if i.remote == nil {
i.ForcePromoteBest(preferredRanges)
return
}
i.promoteCounter++
if i.promoteCounter%PromoteEvery == 0 {
// return early if we are already on a preferred remote
rIP := udp2ip(i.remote)
for _, l := range preferredRanges {
if l.Contains(rIP) {
return
}
}
// We re-query the lighthouse periodically while sending packets, so
// check for new remotes in our local lighthouse cache
ips := ifce.lightHouse.QueryCache(i.hostId)
for _, ip := range ips {
i.AddRemote(ip)
}
best, preferred := i.getBestRemote(preferredRanges)
if preferred && !best.Equals(i.remote) {
// Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes
ifce.send(test, testRequest, i.ConnectionState, i, best, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}
}
}
func (i *HostInfo) ForcePromoteBest(preferredRanges []*net.IPNet) {
best, _ := i.getBestRemote(preferredRanges)
if best != nil {
i.remote = best
}
}
func (i *HostInfo) getBestRemote(preferredRanges []*net.IPNet) (best *udpAddr, preferred bool) {
if len(i.Remotes) > 0 {
for _, r := range i.Remotes {
rIP := udp2ip(r.addr)
for _, l := range preferredRanges {
if l.Contains(rIP) {
return r.addr, true
}
}
if best == nil || !PrivateIP(rIP) {
best = r.addr
}
/*
for _, r := range i.Remotes {
// Must have > 80% probe success to be considered.
//fmt.Println("GRADE:", r.addr.IP, r.Grade())
if r.Grade() > float64(.8) {
if localToMe.Contains(r.addr.IP) == true {
best = r.addr
break
//i.remote = i.Remotes[c].addr
} else {
//}
}
*/
}
return best, false
}
return nil, false
}
// rotateRemote will move remote to the next ip in the list of remote ips for this host
// This is different than PromoteBest in that what is algorithmically best may not actually work.
// Only known use case is when sending a stage 0 handshake.
// It may be better to just send stage 0 handshakes to all known ips and sort it out in the receiver.
func (i *HostInfo) rotateRemote() {
// We have 0, can't rotate
if len(i.Remotes) < 1 {
return
}
if i.remote == nil {
i.remote = i.Remotes[0].addr
return
}
// We want to look at all but the very last entry since that is handled at the end
for x := 0; x < len(i.Remotes)-1; x++ {
// Find our current position and move to the next one in the list
if i.Remotes[x].addr.Equals(i.remote) {
i.remote = i.Remotes[x+1].addr
return
}
}
// Our current position was likely the last in the list, start over at 0
i.remote = i.Remotes[0].addr
}
func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
//TODO: return the error so we can log with more context
if len(i.packetStore) < 100 {
tempPacket := make([]byte, len(packet))
copy(tempPacket, packet)
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
l.WithField("vpnIp", IntIp(i.hostId)).
WithField("length", len(i.packetStore)).
WithField("stored", true).
Debugf("Packet store")
} else if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(i.hostId)).
WithField("length", len(i.packetStore)).
WithField("stored", false).
Debugf("Packet store")
}
}
// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
func (i *HostInfo) handshakeComplete() {
//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
i.ConnectionState.queueLock.Lock()
i.HandshakeComplete = true
//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
// Clamping it to 2 gets us out of the woods for now
*i.ConnectionState.messageCounter = 2
l.WithField("vpnIp", IntIp(i.hostId)).Debugf("Sending %d stored packets", len(i.packetStore))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for _, cp := range i.packetStore {
cp.callback(cp.messageType, cp.messageSubType, i, cp.packet, nb, out)
}
i.packetStore = make([]*cachedPacket, 0)
i.ConnectionState.ready = true
i.ConnectionState.queueLock.Unlock()
i.ConnectionState.certState = nil
}
func (i *HostInfo) RemoteUDPAddrs() []*udpAddr {
var addrs []*udpAddr
for _, r := range i.Remotes {
addrs = append(addrs, r.addr)
}
return addrs
}
func (i *HostInfo) GetCert() *cert.NebulaCertificate {
if i.ConnectionState != nil {
return i.ConnectionState.peerCert
}
return nil
}
func (i *HostInfo) AddRemote(r udpAddr) *udpAddr {
remote := &r
//add := true
for _, r := range i.Remotes {
if r.addr.Equals(remote) {
return r.addr
//add = false
}
}
// Trim this down if necessary
if len(i.Remotes) > MaxRemotes {
i.Remotes = i.Remotes[len(i.Remotes)-MaxRemotes:]
}
i.Remotes = append(i.Remotes, NewHostInfoDest(remote))
return remote
//l.Debugf("Added remote %s for vpn ip", remote)
}
func (i *HostInfo) SetRemote(remote udpAddr) {
i.remote = i.AddRemote(remote)
}
func (i *HostInfo) ClearRemotes() {
i.remote = nil
i.Remotes = []*HostInfoDest{}
}
func (i *HostInfo) ClearConnectionState() {
i.ConnectionState = nil
}
func (i *HostInfo) RecvErrorExceeded() bool {
if i.recvError < 3 {
i.recvError += 1
return false
}
return true
}
//########################
func NewHostInfoDest(addr *udpAddr) *HostInfoDest {
i := &HostInfoDest{
addr: addr,
}
return i
}
func (hid *HostInfoDest) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"active": hid.active,
"address": hid.addr,
"probe_count": hid.probeCounter,
})
}
/*
func (hm *HostMap) DebugRemotes(vpnIp uint32) string {
s := "\n"
for _, h := range hm.Hosts {
for _, r := range h.Remotes {
s += fmt.Sprintf("%s : %d ## %v\n", r.addr.IP.String(), r.addr.Port, r.probes)
}
}
return s
}
func (d *HostInfoDest) Grade() float64 {
c1 := ProbeLen
for n := len(d.probes) - 1; n >= 0; n-- {
if d.probes[n] == true {
c1 -= 1
}
}
return float64(c1) / float64(ProbeLen)
}
func (d *HostInfoDest) Grade() (float64, float64, float64) {
c1 := ProbeLen
c2 := ProbeLen / 2
c2c := ProbeLen - ProbeLen/2
c3 := ProbeLen / 5
c3c := ProbeLen - ProbeLen/5
for n := len(d.probes) - 1; n >= 0; n-- {
if d.probes[n] == true {
c1 -= 1
if n >= c2c {
c2 -= 1
if n >= c3c {
c3 -= 1
}
}
}
//if n >= d {
}
return float64(c3) / float64(ProbeLen/5), float64(c2) / float64(ProbeLen/2), float64(c1) / float64(ProbeLen)
//return float64(c1) / float64(ProbeLen), float64(c2) / float64(ProbeLen/2), float64(c3) / float64(ProbeLen/5)
}
func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) {
for _, r := range i.Remotes {
if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port {
r.ProbeReceived(counter)
}
}
}
func (i *HostInfo) Probes() []*Probe {
p := []*Probe{}
for _, d := range i.Remotes {
p = append(p, &Probe{Addr: d.addr, Counter: d.Probe()})
}
return p
}
func (d *HostInfoDest) Probe() int {
//d.probes = append(d.probes, true)
d.probeCounter++
d.probes[d.probeCounter%ProbeLen] = true
return d.probeCounter
//return d.probeCounter
}
func (d *HostInfoDest) ProbeReceived(probeCount int) {
if probeCount >= (d.probeCounter - ProbeLen) {
//fmt.Println("PROBE WORKED", probeCount)
//fmt.Println(d.addr, d.Grade())
d.probes[probeCount%ProbeLen] = false
}
}
*/
// Utility functions
func localIps() *[]net.IP {
//FIXME: This function is pretty garbage
var ips []net.IP
ifaces, _ := net.Interfaces()
for _, i := range ifaces {
addrs, _ := i.Addrs()
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
//continue
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip.To4() != nil && ip.IsLoopback() == false {
ips = append(ips, ip)
}
}
}
return &ips
}
func PrivateIP(ip net.IP) bool {
private := false
_, private24BitBlock, _ := net.ParseCIDR("10.0.0.0/8")
_, private20BitBlock, _ := net.ParseCIDR("172.16.0.0/12")
_, private16BitBlock, _ := net.ParseCIDR("192.168.0.0/16")
private = private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
return private
}

166
hostmap_test.go Normal file
View File

@ -0,0 +1,166 @@
package nebula
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
/*
func TestHostInfoDestProbe(t *testing.T) {
a, _ := net.ResolveUDPAddr("udp", "1.0.0.1:22222")
d := NewHostInfoDest(a)
// 999 probes that all return should give a 100% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
d.ProbeReceived(meh)
}
assert.Equal(t, d.Grade(), float64(1))
// 999 probes of which only half return should give a 50% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%2 == 0 {
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.5))
// 999 probes of which none return should give a 0% success rate
for i := 0; i < 999; i++ {
d.Probe()
}
assert.Equal(t, d.Grade(), float64(0))
// 999 probes of which only 1/4 return should give a 25% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%4 == 0 {
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.25))
// 999 probes of which only half return and are duplicates should give a 50% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
if i%2 == 0 {
d.ProbeReceived(meh)
d.ProbeReceived(meh)
}
}
assert.Equal(t, d.Grade(), float64(.5))
// 999 probes of which only way old replies return should give a 0% success rate
for i := 0; i < 999; i++ {
meh := d.Probe()
d.ProbeReceived(meh - 101)
}
assert.Equal(t, d.Grade(), float64(0))
}
*/
func TestHostmap(t *testing.T) {
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
myNets := []*net.IPNet{myNet}
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
y := NewUDPAddrFromString("10.128.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
info, _ := m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1")))
// There should be three remotes in the host map
assert.Equal(t, 3, len(info.Remotes))
// Adding an identical remote should not change the count
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
assert.Equal(t, 3, len(info.Remotes))
// Adding a fresh remote should add one
y = NewUDPAddrFromString("10.18.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
assert.Equal(t, 4, len(info.Remotes))
// Query and reference remote should get the first one (and not nil)
info, _ = m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1")))
assert.NotNil(t, info.remote)
// Promotion should ensure that the best remote is chosen (y)
info.ForcePromoteBest(myNets)
assert.True(t, myNet.Contains(udp2ip(info.remote)))
}
func TestHostmapdebug(t *testing.T) {
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
y := NewUDPAddrFromString("10.128.0.3:11111")
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
//t.Errorf("%s", m.DebugRemotes(1))
}
func TestHostMap_rotateRemote(t *testing.T) {
h := HostInfo{}
// 0 remotes, no panic
h.rotateRemote()
assert.Nil(t, h.remote)
// 1 remote, no panic
h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 1}), 0))
h.rotateRemote()
assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 1}))
h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 2}), 0))
h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 3}), 0))
h.AddRemote(*NewUDPAddr(ip2int(net.IP{1, 1, 1, 4}), 0))
// Rotate through those 3
h.rotateRemote()
assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 2}))
h.rotateRemote()
assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 3}))
h.rotateRemote()
assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 4}))
// Finally, we should start over
h.rotateRemote()
assert.Equal(t, udp2ipInt(h.remote), ip2int(net.IP{1, 1, 1, 1}))
}
func BenchmarkHostmappromote2(b *testing.B) {
for n := 0; n < b.N; n++ {
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges)
y := NewUDPAddrFromString("10.128.0.3:11111")
a := NewUDPAddrFromString("10.127.0.3:11111")
g := NewUDPAddrFromString("1.0.0.1:22222")
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), g)
m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
}
b.Errorf("hi")
}

201
inside.go Normal file
View File

@ -0,0 +1,201 @@
package nebula
import (
"sync/atomic"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte) {
err := newPacket(packet, false, fwPacket)
if err != nil {
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
return
}
// Ignore local broadcast packets
if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast {
return
}
// Ignore broadcast packets
if f.dropMulticast && isMulticast(fwPacket.RemoteIP) {
return
}
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
ci := hostinfo.ConnectionState
if ci.ready == false {
// Because we might be sending stored packets, lock here to stop new things going to
// the packet queue.
ci.queueLock.Lock()
if !ci.ready {
hostinfo.cachePacket(message, 0, packet, f.sendMessageNow)
ci.queueLock.Unlock()
return
}
ci.queueLock.Unlock()
}
if !f.firewall.Drop(packet, *fwPacket, false, ci.peerCert, trustedCAs) {
f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out)
if f.lightHouse != nil && *ci.messageCounter%5000 == 0 {
f.lightHouse.Query(fwPacket.RemoteIP, f)
}
} else if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet")
}
}
func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
//if err != nil || hostinfo.ConnectionState == nil {
if err != nil {
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIp)
if err != nil {
hostinfo = f.handshakeManager.AddVpnIP(vpnIp)
}
}
ci := hostinfo.ConnectionState
if ci != nil && ci.eKey != nil && ci.ready {
return hostinfo
}
if ci == nil {
// if we don't have a connection state, then send a handshake initiation
ci = f.newConnectionState(true, noise.HandshakeIX, []byte{}, 0)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
hostinfo.ConnectionState = ci
} else if ci.eKey == nil {
// if we don't have any state at all, create it
}
// If we have already created the handshake packet, we don't want to call the function at all.
if !hostinfo.HandshakeReady {
ixHandshakeStage0(f, vpnIp, hostinfo)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//xx_handshakeStage0(f, ip, hostinfo)
}
return hostinfo
}
func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
fp := &FirewallPacket{}
err := newPacket(p, false, fp)
if err != nil {
l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
return
}
// check if packet is in outbound fw rules
if f.firewall.Drop(p, *fp, false, hostInfo.ConnectionState.peerCert, trustedCAs) {
l.WithField("fwPacket", fp).Debugln("dropping cached packet")
return
}
f.send(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 {
f.lightHouse.Query(fp.RemoteIP, f)
}
}
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
if !hostInfo.ConnectionState.ready {
// Because we might be sending stored packets, lock here to stop new things going to
// the packet queue.
hostInfo.ConnectionState.queueLock.Lock()
if !hostInfo.ConnectionState.ready {
hostInfo.cachePacket(t, st, p, f.sendMessageToVpnIp)
hostInfo.ConnectionState.queueLock.Unlock()
return
}
hostInfo.ConnectionState.queueLock.Unlock()
}
f.sendMessageToVpnIp(t, st, hostInfo, p, nb, out)
return
}
func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
}
// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
if hostInfo.ConnectionState.ready == false {
// Because we might be sending stored packets, lock here to stop new things going to
// the packet queue.
hostInfo.ConnectionState.queueLock.Lock()
if !hostInfo.ConnectionState.ready {
hostInfo.cachePacket(t, st, p, f.sendMessageToAll)
hostInfo.ConnectionState.queueLock.Unlock()
return
}
hostInfo.ConnectionState.queueLock.Unlock()
}
f.sendMessageToAll(t, st, hostInfo, p, nb, out)
return
}
func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, b []byte) {
for _, r := range hostInfo.RemoteUDPAddrs() {
f.send(t, st, hostInfo.ConnectionState, hostInfo, r, p, nb, b)
}
}
func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
if ci.eKey == nil {
//TODO: log warning
return
}
var err error
//TODO: enable if we do more than 1 tun queue
//ci.writeLock.Lock()
c := atomic.AddUint64(ci.messageCounter, 1)
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo.hostId)
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
//TODO: see above note on lock
//ci.writeLock.Unlock()
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", ci.messageCounter).
Error("Failed to encrypt outgoing packet")
return
}
err = f.outside.WriteTo(out, remote)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
}
func isMulticast(ip uint32) bool {
// Class D multicast
if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 {
return true
}
return false
}

277
interface.go Normal file
View File

@ -0,0 +1,277 @@
package nebula
import (
"crypto/sha256"
"errors"
"fmt"
"io"
"os"
"time"
"github.com/rcrowley/go-metrics"
"golang.org/x/crypto/hkdf"
)
const mtu = 9001
type InterfaceConfig struct {
HostMap *HostMap
Outside *udpConn
Inside *Tun
certState *CertState
Cipher string
Firewall *Firewall
ServeDns bool
HandshakeManager *HandshakeManager
lightHouse *LightHouse
checkInterval int
pendingDeletionInterval int
handshakeMACKey string
handshakeAcceptedMACKeys []string
DropLocalBroadcast bool
DropMulticast bool
UDPBatchSize int
}
type Interface struct {
hostMap *HostMap
outside *udpConn
inside *Tun
certState *CertState
cipher string
firewall *Firewall
connectionManager *connectionManager
handshakeManager *HandshakeManager
serveDns bool
createTime time.Time
lightHouse *LightHouse
handshakeMACKey []byte
handshakeAcceptedMACKeys [][]byte
localBroadcast uint32
dropLocalBroadcast bool
dropMulticast bool
udpBatchSize int
version string
metricRxRecvError metrics.Counter
metricTxRecvError metrics.Counter
metricHandshakes metrics.Histogram
}
func NewInterface(c *InterfaceConfig) (*Interface, error) {
if c.Outside == nil {
return nil, errors.New("no outside connection")
}
if c.Inside == nil {
return nil, errors.New("no inside interface (tun)")
}
if c.certState == nil {
return nil, errors.New("no certificate state")
}
if c.Firewall == nil {
return nil, errors.New("no firewall rules")
}
// Use KDF to make this useful
hmacKey, err := sha256KdfFromString(c.handshakeMACKey)
if err != nil {
l.Debugln(err)
}
allowedMacs := make([][]byte, 0)
//allowedMacs = append(allowedMacs, mac)
if len(c.handshakeAcceptedMACKeys) > 0 {
for _, k := range c.handshakeAcceptedMACKeys {
// Use KDF to make these useful too
hmacKey, err := sha256KdfFromString(k)
if err != nil {
l.Debugln(err)
}
allowedMacs = append(allowedMacs, hmacKey)
}
} else {
if len(c.handshakeMACKey) > 0 {
l.Warnln("You have set an outgoing MAC but do not accept any incoming. This is probably not what you want.")
} else {
// This else is a fallback if we have not set any mac keys at all
hmacKey, err := sha256KdfFromString("")
if err != nil {
l.Debugln(err)
}
allowedMacs = append(allowedMacs, hmacKey)
}
}
ifce := &Interface{
hostMap: c.HostMap,
outside: c.Outside,
inside: c.Inside,
certState: c.certState,
cipher: c.Cipher,
firewall: c.Firewall,
serveDns: c.ServeDns,
handshakeManager: c.HandshakeManager,
createTime: time.Now(),
lightHouse: c.lightHouse,
handshakeMACKey: hmacKey,
handshakeAcceptedMACKeys: allowedMacs,
localBroadcast: ip2int(c.certState.certificate.Details.Ips[0].IP) | ^ip2int(c.certState.certificate.Details.Ips[0].Mask),
dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast,
udpBatchSize: c.UDPBatchSize,
metricRxRecvError: metrics.GetOrRegisterCounter("messages.rx.recv_error", nil),
metricTxRecvError: metrics.GetOrRegisterCounter("messages.tx.recv_error", nil),
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
}
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
return ifce, nil
}
func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) {
// actually turn on tun dev
if err := f.inside.Activate(); err != nil {
l.Fatal(err)
}
f.version = buildVersion
l.WithField("interface", f.inside.Device).WithField("network", f.inside.Cidr.String()).
WithField("build", buildVersion).
Info("Nebula interface is active")
// Launch n queues to read packets from udp
for i := 0; i < udpRoutines; i++ {
go f.listenOut(i)
}
// Launch n queues to read packets from tun dev
for i := 0; i < tunRoutines; i++ {
go f.listenIn(i)
}
}
func (f *Interface) listenOut(i int) {
//TODO: handle error
addr, err := f.outside.LocalAddr()
if err != nil {
l.WithError(err).Error("failed to discover udp listening address")
}
var li *udpConn
if i > 0 {
//TODO: handle error
li, err = NewListener(udp2ip(addr).String(), int(addr.Port), i > 0)
if err != nil {
l.WithError(err).Error("failed to make a new udp listener")
}
} else {
li = f.outside
}
li.ListenOut(f)
}
func (f *Interface) listenIn(i int) {
packet := make([]byte, mtu)
out := make([]byte, mtu)
fwPacket := &FirewallPacket{}
nb := make([]byte, 12, 12)
for {
n, err := f.inside.Read(packet)
if err != nil {
l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out)
}
}
func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
c.RegisterReloadCallback(f.reloadCA)
c.RegisterReloadCallback(f.reloadCertKey)
c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.outside.reloadConfig)
}
func (f *Interface) reloadCA(c *Config) {
// reload and check regardless
// todo: need mutex?
newCAs, err := loadCAFromConfig(c)
if err != nil {
l.WithError(err).Error("Could not refresh trusted CA certificates")
return
}
trustedCAs = newCAs
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
}
func (f *Interface) reloadCertKey(c *Config) {
// reload and check in all cases
cs, err := NewCertStateFromConfig(c)
if err != nil {
l.WithError(err).Error("Could not refresh client cert")
return
}
// did IP in cert change? if so, don't set
oldIPs := f.certState.certificate.Details.Ips
newIPs := cs.certificate.Details.Ips
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
return
}
f.certState = cs
l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
}
func (f *Interface) reloadFirewall(c *Config) {
//TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false {
l.Debug("No firewall config change detected")
return
}
fw, err := NewFirewallFromConfig(f.certState.certificate, c)
if err != nil {
l.WithError(err).Error("Error while creating firewall during reload")
return
}
oldFw := f.firewall
f.firewall = fw
oldFw.Destroy()
l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()).
Info("New firewall has been installed")
}
func (f *Interface) emitStats(i time.Duration) {
ticker := time.NewTicker(i)
for range ticker.C {
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
}
}
func sha256KdfFromString(secret string) ([]byte, error) {
// Use KDF to make this useful
mac := []byte(secret)
hmacKey := make([]byte, sha256.BlockSize)
hash := sha256.New
hkdfer := hkdf.New(hash, []byte(mac), nil, nil)
n, err := io.ReadFull(hkdfer, hmacKey)
if n != len(hmacKey) || err != nil {
l.Errorln("KDF Failed!")
return nil, fmt.Errorf("%s", err)
}
return hmacKey, nil
}

368
lighthouse.go Normal file
View File

@ -0,0 +1,368 @@
package nebula
import (
"fmt"
"net"
"sync"
"time"
"github.com/golang/protobuf/proto"
"github.com/slackhq/nebula/cert"
)
type LightHouse struct {
sync.RWMutex //Because we concurrently read and write to our maps
amLighthouse bool
myIp uint32
punchConn *udpConn
// Local cache of answers from light houses
addrMap map[uint32][]udpAddr
// staticList exists to avoid having a bool in each addrMap entry
// since static should be rare
staticList map[uint32]struct{}
lighthouses map[uint32]struct{}
interval int
nebulaPort int
punchBack bool
}
type EncWriter interface {
SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
}
func NewLightHouse(amLighthouse bool, myIp uint32, ips []string, interval int, nebulaPort int, pc *udpConn, punchBack bool) *LightHouse {
h := LightHouse{
amLighthouse: amLighthouse,
myIp: myIp,
addrMap: make(map[uint32][]udpAddr),
nebulaPort: nebulaPort,
lighthouses: make(map[uint32]struct{}),
staticList: make(map[uint32]struct{}),
interval: interval,
punchConn: pc,
punchBack: punchBack,
}
for _, rIp := range ips {
h.lighthouses[ip2int(net.ParseIP(rIp))] = struct{}{}
}
return &h
}
func (lh *LightHouse) Query(ip uint32, f EncWriter) ([]udpAddr, error) {
if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip, f)
}
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock()
return v, nil
}
lh.RUnlock()
return nil, fmt.Errorf("host %s not known, queries sent to lighthouses", IntIp(ip))
}
// This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
if !lh.amLighthouse {
// Send a query to the lighthouses and hope for the best next time
query, err := proto.Marshal(NewLhQueryByInt(ip))
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
return
}
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for n := range lh.lighthouses {
f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
}
}
}
// Query our local lighthouse cached results
func (lh *LightHouse) QueryCache(ip uint32) []udpAddr {
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock()
return v
}
lh.RUnlock()
return nil
}
func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
// First we check the static mapping
// and do nothing if it is there
if _, ok := lh.staticList[vpnIP]; ok {
return
}
lh.Lock()
//l.Debugln(lh.addrMap)
delete(lh.addrMap, vpnIP)
l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
lh.Unlock()
}
func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) {
// First we check if the sender thinks this is a static entry
// and do nothing if it is not, but should be considered static
if static == false {
if _, ok := lh.staticList[vpnIP]; ok {
return
}
}
lh.Lock()
for _, v := range lh.addrMap[vpnIP] {
if v.Equals(toIp) {
lh.Unlock()
return
}
}
//l.Debugf("Adding reply of %s as %s\n", IntIp(vpnIP), toIp)
if static {
lh.staticList[vpnIP] = struct{}{}
}
lh.addrMap[vpnIP] = append(lh.addrMap[vpnIP], *toIp)
lh.Unlock()
}
func (lh *LightHouse) AddRemoteAndReset(vpnIP uint32, toIp *udpAddr) {
if lh.amLighthouse {
lh.DeleteVpnIP(vpnIP)
lh.AddRemote(vpnIP, toIp, false)
}
}
func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
if _, ok := lh.lighthouses[vpnIP]; ok {
return true
}
return false
}
// Quick generators for protobuf
func NewLhQueryByIpString(VpnIp string) *NebulaMeta {
return NewLhQueryByInt(ip2int(net.ParseIP(VpnIp)))
}
func NewLhQueryByInt(VpnIp uint32) *NebulaMeta {
return &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: VpnIp,
},
}
}
func NewLhWhoami() *NebulaMeta {
return &NebulaMeta{
Type: NebulaMeta_HostWhoami,
Details: &NebulaMetaDetails{},
}
}
// End Quick generators for protobuf
func NewIpAndPortFromUDPAddr(addr udpAddr) *IpAndPort {
return &IpAndPort{Ip: udp2ipInt(&addr), Port: uint32(addr.Port)}
}
func NewIpAndPortsFromNetIps(ips []udpAddr) *[]*IpAndPort {
var iap []*IpAndPort
for _, e := range ips {
// Only add IPs that aren't my VPN/tun IP
iap = append(iap, NewIpAndPortFromUDPAddr(e))
}
return &iap
}
func (lh *LightHouse) LhUpdateWorker(f EncWriter) {
if lh.amLighthouse {
return
}
for {
ipp := []*IpAndPort{}
for _, e := range *localIps() {
// Only add IPs that aren't my VPN/tun IP
if ip2int(e) != lh.myIp {
ipp = append(ipp, &IpAndPort{Ip: ip2int(e), Port: uint32(lh.nebulaPort)})
//fmt.Println(e)
}
}
m := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
VpnIp: lh.myIp,
IpAndPorts: ipp,
},
}
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for vpnIp := range lh.lighthouses {
mm, err := proto.Marshal(m)
if err != nil {
l.Debugf("Invalid marshal to update")
}
//l.Error("LIGHTHOUSE PACKET SEND", mm)
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
}
time.Sleep(time.Second * time.Duration(lh.interval))
}
}
func (lh *LightHouse) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, c *cert.NebulaCertificate, f EncWriter) {
n := &NebulaMeta{}
err := proto.Unmarshal(p, n)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
Error("Failed to unmarshal lighthouse packet")
//TODO: send recv_error?
return
}
if n.Details == nil {
l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
Error("Invalid lighthouse update")
//TODO: send recv_error?
return
}
switch n.Type {
case NebulaMeta_HostQuery:
// Exit if we don't answer queries
if !lh.amLighthouse {
l.Debugln("I don't answer queries, but received from: ", rAddr)
return
}
//l.Debugln("Got Query")
ips, err := lh.Query(n.Details.VpnIp, f)
if err != nil {
//l.Debugf("Can't answer query %s from %s because error: %s", IntIp(n.Details.VpnIp), rAddr, err)
return
} else {
iap := NewIpAndPortsFromNetIps(ips)
answer := &NebulaMeta{
Type: NebulaMeta_HostQueryReply,
Details: &NebulaMetaDetails{
VpnIp: n.Details.VpnIp,
IpAndPorts: *iap,
},
}
reply, err := proto.Marshal(answer)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
return
}
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
// This signals the other side to punch some zero byte udp packets
ips, err = lh.Query(vpnIp, f)
if err != nil {
l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
return
} else {
//l.Debugln("Notify host to punch", iap)
iap = NewIpAndPortsFromNetIps(ips)
answer = &NebulaMeta{
Type: NebulaMeta_HostPunchNotification,
Details: &NebulaMetaDetails{
VpnIp: vpnIp,
IpAndPorts: *iap,
},
}
reply, _ := proto.Marshal(answer)
f.SendMessageToVpnIp(lightHouse, 0, n.Details.VpnIp, reply, make([]byte, 12, 12), make([]byte, mtu))
}
//fmt.Println(reply, remoteaddr)
}
case NebulaMeta_HostQueryReply:
if !lh.IsLighthouseIP(vpnIp) {
return
}
for _, a := range n.Details.IpAndPorts {
//first := n.Details.IpAndPorts[0]
ans := NewUDPAddr(a.Ip, uint16(a.Port))
lh.AddRemote(n.Details.VpnIp, ans, false)
}
case NebulaMeta_HostUpdateNotification:
//Simple check that the host sent this not someone else
if n.Details.VpnIp != vpnIp {
l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
return
}
for _, a := range n.Details.IpAndPorts {
ans := NewUDPAddr(a.Ip, uint16(a.Port))
lh.AddRemote(n.Details.VpnIp, ans, false)
}
case NebulaMeta_HostMovedNotification:
case NebulaMeta_HostPunchNotification:
if !lh.IsLighthouseIP(vpnIp) {
return
}
empty := []byte{0}
for _, a := range n.Details.IpAndPorts {
vpnPeer := NewUDPAddr(a.Ip, uint16(a.Port))
go func() {
for i := 0; i < 5; i++ {
lh.punchConn.WriteTo(empty, vpnPeer)
time.Sleep(time.Second * 1)
}
}()
l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
}
// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lh.punchBack {
go func() {
time.Sleep(time.Second * 5)
l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
f.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}()
}
}
}
/*
func (f *Interface) sendPathCheck(ci *ConnectionState, endpoint *net.UDPAddr, counter int) {
c := ci.messageCounter
b := HeaderEncode(nil, Version, uint8(path_check), 0, ci.remoteIndex, c)
ci.messageCounter++
if ci.eKey != nil {
msg := ci.eKey.EncryptDanger(b, nil, []byte(strconv.Itoa(counter)), c)
//msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c)
f.outside.WriteTo(msg, endpoint)
l.Debugf("path_check sent, remote index: %d, pathCounter %d", ci.remoteIndex, counter)
}
}
func (f *Interface) sendPathCheckReply(ci *ConnectionState, endpoint *net.UDPAddr, counter []byte) {
c := ci.messageCounter
b := HeaderEncode(nil, Version, uint8(path_check_reply), 0, ci.remoteIndex, c)
ci.messageCounter++
if ci.eKey != nil {
msg := ci.eKey.EncryptDanger(b, nil, counter, c)
f.outside.WriteTo(msg, endpoint)
l.Debugln("path_check sent, remote index: ", ci.remoteIndex)
}
}
*/

76
lighthouse_test.go Normal file
View File

@ -0,0 +1,76 @@
package nebula
import (
"net"
"testing"
proto "github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
)
func TestNewLhQuery(t *testing.T) {
myIp := net.ParseIP("192.1.1.1")
myIpint := ip2int(myIp)
// Generating a new lh query should work
a := NewLhQueryByInt(myIpint)
// The result should be a nebulameta protobuf
assert.IsType(t, &NebulaMeta{}, a)
// It should also Marshal fine
b, err := proto.Marshal(a)
assert.Nil(t, err)
// and then Unmarshal fine
n := &NebulaMeta{}
err = proto.Unmarshal(b, n)
assert.Nil(t, err)
}
func TestNewipandportfromudpaddr(t *testing.T) {
blah := NewUDPAddrFromString("1.2.2.3:12345")
meh := NewIpAndPortFromUDPAddr(*blah)
assert.Equal(t, uint32(16908803), meh.Ip)
assert.Equal(t, uint32(12345), meh.Port)
}
func TestNewipandportsfromudpaddrs(t *testing.T) {
blah := NewUDPAddrFromString("1.2.2.3:12345")
blah2 := NewUDPAddrFromString("9.9.9.9:47828")
group := []udpAddr{*blah, *blah2}
hah := NewIpAndPortsFromNetIps(group)
assert.IsType(t, &[]*IpAndPort{}, hah)
//t.Error(reflect.TypeOf(hah))
}
/*
func TestLHQuery(t *testing.T) {
//n := NewLhQueryByIpString("10.128.0.3")
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
m := NewHostMap(myNet)
y, _ := net.ResolveUDPAddr("udp", "10.128.0.3:11111")
m.Add(ip2int(net.ParseIP("127.0.0.1")), y)
//t.Errorf("%s", m)
_ = m
_, n, _ := net.ParseCIDR("127.0.0.1/8")
/*udpServer, err := net.ListenUDP("udp", &net.UDPAddr{Port: 10009})
if err != nil {
t.Errorf("%s", err)
}
meh := NewLightHouse(n, m, []string{"10.128.0.2"}, false, 10, 10003, 10004)
//t.Error(m.Hosts)
meh2, err := meh.Query(ip2int(net.ParseIP("10.128.0.3")))
t.Error(err)
if err != nil {
return
}
t.Errorf("%s", meh2)
t.Errorf("%s", n)
}
*/

321
main.go Normal file
View File

@ -0,0 +1,321 @@
package nebula
import (
"encoding/binary"
"fmt"
"net"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
"github.com/slackhq/nebula/sshd"
)
var l = logrus.New()
type m map[string]interface{}
func Main(configPath string, configTest bool, buildVersion string) {
l.Out = os.Stdout
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
}
config := NewConfig()
err := config.Load(configPath)
if err != nil {
l.WithError(err).Error("Failed to load config")
os.Exit(1)
}
// Print the config if in test, the exit comes later
if configTest {
b, err := yaml.Marshal(config.Settings)
if err != nil {
l.Println(err)
os.Exit(1)
}
l.Println(string(b))
}
err = configLogger(config)
if err != nil {
l.WithError(err).Error("Failed to configure the logger")
}
config.RegisterReloadCallback(func(c *Config) {
err := configLogger(c)
if err != nil {
l.WithError(err).Error("Failed to configure the logger")
}
})
// trustedCAs is currently a global, so loadCA operates on that global directly
trustedCAs, err = loadCAFromConfig(config)
if err != nil {
//The errors coming out of loadCA are already nicely formatted
l.Fatal(err)
}
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
cs, err := NewCertStateFromConfig(config)
if err != nil {
//The errors coming out of NewCertStateFromConfig are already nicely formatted
l.Fatal(err)
}
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(cs.certificate, config)
if err != nil {
l.Fatal("Error while loading firewall rules: ", err)
}
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
// TODO: make sure mask is 4 bytes
tunCidr := cs.certificate.Details.Ips[0]
routes, err := parseRoutes(config, tunCidr)
if err != nil {
l.WithError(err).Fatal("Could not parse tun.routes")
}
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
wireSSHReload(ssh, config)
if config.GetBool("sshd.enabled", false) {
err = configSSH(ssh, config)
if err != nil {
l.WithError(err).Fatal("Error while configuring the sshd")
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// All non system modifying configuration consumption should live above this line
// tun config, listeners, anything modifying the computer should be below
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
if configTest {
os.Exit(0)
}
config.CatchHUP()
// set up our tun dev
tun, err := newTun(
config.GetString("tun.dev", ""),
tunCidr,
config.GetInt("tun.mtu", 1300),
routes,
config.GetInt("tun.tx_queue", 500),
)
if err != nil {
l.Fatal(err)
}
// set up our UDP listener
udpQueues := config.GetInt("listen.routines", 1)
udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
if err != nil {
l.Fatal(err)
}
udpServer.reloadConfig(config)
// Set up my internal host map
var preferredRanges []*net.IPNet
rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
// First, check if 'preferred_ranges' is set and fallback to 'local_range'
if len(rawPreferredRanges) > 0 {
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
l.Fatal(err)
}
preferredRanges = append(preferredRanges, preferredRange)
}
}
// local_range was superseded by preferred_ranges. If it is still present,
// merge the local_range setting into preferred_ranges. We will probably
// deprecate local_range and remove in the future.
rawLocalRange := config.GetString("local_range", "")
if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil {
l.Fatal(err)
}
// Check if the entry for local_range was already specified in
// preferred_ranges. Don't put it into the slice twice if so.
var found bool
for _, r := range preferredRanges {
if r.String() == localRange.String() {
found = true
break
}
}
if !found {
preferredRanges = append(preferredRanges, localRange)
}
}
hostMap := NewHostMap("main", tunCidr, preferredRanges)
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
/*
config.SetDefault("promoter.interval", 10)
go hostMap.Promoter(config.GetInt("promoter.interval"))
*/
punchy := config.GetBool("punchy", false)
if punchy == true {
l.Info("UDP hole punching enabled")
go hostMap.Punchy(udpServer)
}
port := config.GetInt("listen.port", 0)
// If port is dynamic, discover it
if port == 0 {
uPort, err := udpServer.LocalAddr()
if err != nil {
l.WithError(err).Fatal("Failed to get listening port")
}
port = int(uPort.Port)
}
punchBack := config.GetBool("punch_back", false)
amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
serveDns := config.GetBool("lighthouse.serve_dns", false)
lightHouse := NewLightHouse(
amLighthouse,
ip2int(tunCidr.IP),
config.GetStringSlice("lighthouse.hosts", []string{}),
//TODO: change to a duration
config.GetInt("lighthouse.interval", 10),
port,
udpServer,
punchBack,
)
if amLighthouse && serveDns {
l.Debugln("Starting dns server")
go dnsMain(hostMap)
}
for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
vals, ok := v.([]interface{})
if ok {
for _, v := range vals {
parts := strings.Split(fmt.Sprintf("%v", v), ":")
addr, err := net.ResolveIPAddr("ip", parts[0])
if err == nil {
ip := addr.IP
port, err := strconv.Atoi(parts[1])
if err != nil {
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
}
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
}
}
} else {
//TODO: make this all a helper
parts := strings.Split(fmt.Sprintf("%v", v), ":")
addr, err := net.ResolveIPAddr("ip", parts[0])
if err == nil {
ip := addr.IP
port, err := strconv.Atoi(parts[1])
if err != nil {
l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
}
lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
}
}
}
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpServer)
handshakeMACKey := config.GetString("handshake_mac.key", "")
handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
checkInterval := config.GetInt("timers.connection_alive_interval", 5)
pendingDeletionInterval := config.GetInt("timers.pending_deletion_interval", 10)
ifConfig := &InterfaceConfig{
HostMap: hostMap,
Inside: tun,
Outside: udpServer,
certState: cs,
Cipher: config.GetString("cipher", "aes"),
Firewall: fw,
ServeDns: serveDns,
HandshakeManager: handshakeManager,
lightHouse: lightHouse,
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
handshakeMACKey: handshakeMACKey,
handshakeAcceptedMACKeys: handshakeAcceptedMACKeys,
DropLocalBroadcast: config.GetBool("tun.drop_local_broadcast", false),
DropMulticast: config.GetBool("tun.drop_multicast", false),
UDPBatchSize: config.GetInt("listen.batch", 64),
}
switch ifConfig.Cipher {
case "aes":
noiseEndiannes = binary.BigEndian
case "chachapoly":
noiseEndiannes = binary.LittleEndian
default:
l.Fatalf("Unknown cipher: %v", ifConfig.Cipher)
}
ifce, err := NewInterface(ifConfig)
if err != nil {
l.Fatal(err)
}
ifce.RegisterConfigChangeCallbacks(config)
go handshakeManager.Run(ifce)
go lightHouse.LhUpdateWorker(ifce)
err = startStats(config)
if err != nil {
l.Fatal(err)
}
//TODO: check if we _should_ be emitting stats
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion)
// Just sit here and be friendly, main thread.
shutdownBlock(ifce)
}
func shutdownBlock(ifce *Interface) {
var sigChan = make(chan os.Signal)
signal.Notify(sigChan, syscall.SIGTERM)
signal.Notify(sigChan, syscall.SIGINT)
sig := <-sigChan
l.WithField("signal", sig).Info("Caught signal, shutting down")
//TODO: stop tun and udp routines, the lock on hostMap does effectively does that though
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
ifce.hostMap.Lock()
for _, h := range ifce.hostMap.Hosts {
if h.ConnectionState.ready {
ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
}
}
ifce.hostMap.Unlock()
l.WithField("signal", sig).Info("Goodbye")
os.Exit(0)
}

1
main_test.go Normal file
View File

@ -0,0 +1 @@
package nebula

18
metadata.go Normal file
View File

@ -0,0 +1,18 @@
package nebula
/*
import (
proto "github.com/golang/protobuf/proto"
)
func HandleMetaProto(p []byte) {
m := &NebulaMeta{}
err := proto.Unmarshal(p, m)
if err != nil {
l.Debugf("problem unmarshaling meta message: %s", err)
}
//fmt.Println(m)
}
*/

457
nebula.pb.go Normal file
View File

@ -0,0 +1,457 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: nebula.proto
package nebula
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type NebulaMeta_MessageType int32
const (
NebulaMeta_None NebulaMeta_MessageType = 0
NebulaMeta_HostQuery NebulaMeta_MessageType = 1
NebulaMeta_HostQueryReply NebulaMeta_MessageType = 2
NebulaMeta_HostUpdateNotification NebulaMeta_MessageType = 3
NebulaMeta_HostMovedNotification NebulaMeta_MessageType = 4
NebulaMeta_HostPunchNotification NebulaMeta_MessageType = 5
NebulaMeta_HostWhoami NebulaMeta_MessageType = 6
NebulaMeta_HostWhoamiReply NebulaMeta_MessageType = 7
NebulaMeta_PathCheck NebulaMeta_MessageType = 8
NebulaMeta_PathCheckReply NebulaMeta_MessageType = 9
)
var NebulaMeta_MessageType_name = map[int32]string{
0: "None",
1: "HostQuery",
2: "HostQueryReply",
3: "HostUpdateNotification",
4: "HostMovedNotification",
5: "HostPunchNotification",
6: "HostWhoami",
7: "HostWhoamiReply",
8: "PathCheck",
9: "PathCheckReply",
}
var NebulaMeta_MessageType_value = map[string]int32{
"None": 0,
"HostQuery": 1,
"HostQueryReply": 2,
"HostUpdateNotification": 3,
"HostMovedNotification": 4,
"HostPunchNotification": 5,
"HostWhoami": 6,
"HostWhoamiReply": 7,
"PathCheck": 8,
"PathCheckReply": 9,
}
func (x NebulaMeta_MessageType) String() string {
return proto.EnumName(NebulaMeta_MessageType_name, int32(x))
}
func (NebulaMeta_MessageType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{0, 0}
}
type NebulaPing_MessageType int32
const (
NebulaPing_Ping NebulaPing_MessageType = 0
NebulaPing_Reply NebulaPing_MessageType = 1
)
var NebulaPing_MessageType_name = map[int32]string{
0: "Ping",
1: "Reply",
}
var NebulaPing_MessageType_value = map[string]int32{
"Ping": 0,
"Reply": 1,
}
func (x NebulaPing_MessageType) String() string {
return proto.EnumName(NebulaPing_MessageType_name, int32(x))
}
func (NebulaPing_MessageType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{3, 0}
}
type NebulaMeta struct {
Type NebulaMeta_MessageType `protobuf:"varint,1,opt,name=Type,json=type,proto3,enum=nebula.NebulaMeta_MessageType" json:"Type,omitempty"`
Details *NebulaMetaDetails `protobuf:"bytes,2,opt,name=Details,json=details,proto3" json:"Details,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *NebulaMeta) Reset() { *m = NebulaMeta{} }
func (m *NebulaMeta) String() string { return proto.CompactTextString(m) }
func (*NebulaMeta) ProtoMessage() {}
func (*NebulaMeta) Descriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{0}
}
func (m *NebulaMeta) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_NebulaMeta.Unmarshal(m, b)
}
func (m *NebulaMeta) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_NebulaMeta.Marshal(b, m, deterministic)
}
func (m *NebulaMeta) XXX_Merge(src proto.Message) {
xxx_messageInfo_NebulaMeta.Merge(m, src)
}
func (m *NebulaMeta) XXX_Size() int {
return xxx_messageInfo_NebulaMeta.Size(m)
}
func (m *NebulaMeta) XXX_DiscardUnknown() {
xxx_messageInfo_NebulaMeta.DiscardUnknown(m)
}
var xxx_messageInfo_NebulaMeta proto.InternalMessageInfo
func (m *NebulaMeta) GetType() NebulaMeta_MessageType {
if m != nil {
return m.Type
}
return NebulaMeta_None
}
func (m *NebulaMeta) GetDetails() *NebulaMetaDetails {
if m != nil {
return m.Details
}
return nil
}
type NebulaMetaDetails struct {
VpnIp uint32 `protobuf:"varint,1,opt,name=VpnIp,json=vpnIp,proto3" json:"VpnIp,omitempty"`
IpAndPorts []*IpAndPort `protobuf:"bytes,2,rep,name=IpAndPorts,json=ipAndPorts,proto3" json:"IpAndPorts,omitempty"`
Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *NebulaMetaDetails) Reset() { *m = NebulaMetaDetails{} }
func (m *NebulaMetaDetails) String() string { return proto.CompactTextString(m) }
func (*NebulaMetaDetails) ProtoMessage() {}
func (*NebulaMetaDetails) Descriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{1}
}
func (m *NebulaMetaDetails) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_NebulaMetaDetails.Unmarshal(m, b)
}
func (m *NebulaMetaDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_NebulaMetaDetails.Marshal(b, m, deterministic)
}
func (m *NebulaMetaDetails) XXX_Merge(src proto.Message) {
xxx_messageInfo_NebulaMetaDetails.Merge(m, src)
}
func (m *NebulaMetaDetails) XXX_Size() int {
return xxx_messageInfo_NebulaMetaDetails.Size(m)
}
func (m *NebulaMetaDetails) XXX_DiscardUnknown() {
xxx_messageInfo_NebulaMetaDetails.DiscardUnknown(m)
}
var xxx_messageInfo_NebulaMetaDetails proto.InternalMessageInfo
func (m *NebulaMetaDetails) GetVpnIp() uint32 {
if m != nil {
return m.VpnIp
}
return 0
}
func (m *NebulaMetaDetails) GetIpAndPorts() []*IpAndPort {
if m != nil {
return m.IpAndPorts
}
return nil
}
func (m *NebulaMetaDetails) GetCounter() uint32 {
if m != nil {
return m.Counter
}
return 0
}
type IpAndPort struct {
Ip uint32 `protobuf:"varint,1,opt,name=Ip,json=ip,proto3" json:"Ip,omitempty"`
Port uint32 `protobuf:"varint,2,opt,name=Port,json=port,proto3" json:"Port,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *IpAndPort) Reset() { *m = IpAndPort{} }
func (m *IpAndPort) String() string { return proto.CompactTextString(m) }
func (*IpAndPort) ProtoMessage() {}
func (*IpAndPort) Descriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{2}
}
func (m *IpAndPort) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_IpAndPort.Unmarshal(m, b)
}
func (m *IpAndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_IpAndPort.Marshal(b, m, deterministic)
}
func (m *IpAndPort) XXX_Merge(src proto.Message) {
xxx_messageInfo_IpAndPort.Merge(m, src)
}
func (m *IpAndPort) XXX_Size() int {
return xxx_messageInfo_IpAndPort.Size(m)
}
func (m *IpAndPort) XXX_DiscardUnknown() {
xxx_messageInfo_IpAndPort.DiscardUnknown(m)
}
var xxx_messageInfo_IpAndPort proto.InternalMessageInfo
func (m *IpAndPort) GetIp() uint32 {
if m != nil {
return m.Ip
}
return 0
}
func (m *IpAndPort) GetPort() uint32 {
if m != nil {
return m.Port
}
return 0
}
type NebulaPing struct {
Type NebulaPing_MessageType `protobuf:"varint,1,opt,name=Type,json=type,proto3,enum=nebula.NebulaPing_MessageType" json:"Type,omitempty"`
Time uint64 `protobuf:"varint,2,opt,name=Time,json=time,proto3" json:"Time,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *NebulaPing) Reset() { *m = NebulaPing{} }
func (m *NebulaPing) String() string { return proto.CompactTextString(m) }
func (*NebulaPing) ProtoMessage() {}
func (*NebulaPing) Descriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{3}
}
func (m *NebulaPing) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_NebulaPing.Unmarshal(m, b)
}
func (m *NebulaPing) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_NebulaPing.Marshal(b, m, deterministic)
}
func (m *NebulaPing) XXX_Merge(src proto.Message) {
xxx_messageInfo_NebulaPing.Merge(m, src)
}
func (m *NebulaPing) XXX_Size() int {
return xxx_messageInfo_NebulaPing.Size(m)
}
func (m *NebulaPing) XXX_DiscardUnknown() {
xxx_messageInfo_NebulaPing.DiscardUnknown(m)
}
var xxx_messageInfo_NebulaPing proto.InternalMessageInfo
func (m *NebulaPing) GetType() NebulaPing_MessageType {
if m != nil {
return m.Type
}
return NebulaPing_Ping
}
func (m *NebulaPing) GetTime() uint64 {
if m != nil {
return m.Time
}
return 0
}
type NebulaHandshake struct {
Details *NebulaHandshakeDetails `protobuf:"bytes,1,opt,name=Details,json=details,proto3" json:"Details,omitempty"`
Hmac []byte `protobuf:"bytes,2,opt,name=Hmac,json=hmac,proto3" json:"Hmac,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} }
func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) }
func (*NebulaHandshake) ProtoMessage() {}
func (*NebulaHandshake) Descriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{4}
}
func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_NebulaHandshake.Unmarshal(m, b)
}
func (m *NebulaHandshake) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_NebulaHandshake.Marshal(b, m, deterministic)
}
func (m *NebulaHandshake) XXX_Merge(src proto.Message) {
xxx_messageInfo_NebulaHandshake.Merge(m, src)
}
func (m *NebulaHandshake) XXX_Size() int {
return xxx_messageInfo_NebulaHandshake.Size(m)
}
func (m *NebulaHandshake) XXX_DiscardUnknown() {
xxx_messageInfo_NebulaHandshake.DiscardUnknown(m)
}
var xxx_messageInfo_NebulaHandshake proto.InternalMessageInfo
func (m *NebulaHandshake) GetDetails() *NebulaHandshakeDetails {
if m != nil {
return m.Details
}
return nil
}
func (m *NebulaHandshake) GetHmac() []byte {
if m != nil {
return m.Hmac
}
return nil
}
type NebulaHandshakeDetails struct {
Cert []byte `protobuf:"bytes,1,opt,name=Cert,json=cert,proto3" json:"Cert,omitempty"`
InitiatorIndex uint32 `protobuf:"varint,2,opt,name=InitiatorIndex,json=initiatorIndex,proto3" json:"InitiatorIndex,omitempty"`
ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,json=responderIndex,proto3" json:"ResponderIndex,omitempty"`
Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,json=cookie,proto3" json:"Cookie,omitempty"`
Time uint64 `protobuf:"varint,5,opt,name=Time,json=time,proto3" json:"Time,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} }
func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) }
func (*NebulaHandshakeDetails) ProtoMessage() {}
func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{5}
}
func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_NebulaHandshakeDetails.Unmarshal(m, b)
}
func (m *NebulaHandshakeDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_NebulaHandshakeDetails.Marshal(b, m, deterministic)
}
func (m *NebulaHandshakeDetails) XXX_Merge(src proto.Message) {
xxx_messageInfo_NebulaHandshakeDetails.Merge(m, src)
}
func (m *NebulaHandshakeDetails) XXX_Size() int {
return xxx_messageInfo_NebulaHandshakeDetails.Size(m)
}
func (m *NebulaHandshakeDetails) XXX_DiscardUnknown() {
xxx_messageInfo_NebulaHandshakeDetails.DiscardUnknown(m)
}
var xxx_messageInfo_NebulaHandshakeDetails proto.InternalMessageInfo
func (m *NebulaHandshakeDetails) GetCert() []byte {
if m != nil {
return m.Cert
}
return nil
}
func (m *NebulaHandshakeDetails) GetInitiatorIndex() uint32 {
if m != nil {
return m.InitiatorIndex
}
return 0
}
func (m *NebulaHandshakeDetails) GetResponderIndex() uint32 {
if m != nil {
return m.ResponderIndex
}
return 0
}
func (m *NebulaHandshakeDetails) GetCookie() uint64 {
if m != nil {
return m.Cookie
}
return 0
}
func (m *NebulaHandshakeDetails) GetTime() uint64 {
if m != nil {
return m.Time
}
return 0
}
func init() {
proto.RegisterEnum("nebula.NebulaMeta_MessageType", NebulaMeta_MessageType_name, NebulaMeta_MessageType_value)
proto.RegisterEnum("nebula.NebulaPing_MessageType", NebulaPing_MessageType_name, NebulaPing_MessageType_value)
proto.RegisterType((*NebulaMeta)(nil), "nebula.NebulaMeta")
proto.RegisterType((*NebulaMetaDetails)(nil), "nebula.NebulaMetaDetails")
proto.RegisterType((*IpAndPort)(nil), "nebula.IpAndPort")
proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing")
proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake")
proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails")
}
func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) }
var fileDescriptor_2d65afa7693df5ef = []byte{
// 491 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x53, 0x41, 0x6f, 0xda, 0x4c,
0x10, 0x8d, 0x61, 0x81, 0x30, 0x80, 0xe3, 0xcc, 0xf7, 0x15, 0x91, 0x1e, 0xaa, 0xc8, 0x87, 0x8a,
0x13, 0x55, 0xc9, 0xa5, 0xd7, 0x8a, 0x1e, 0xe0, 0x00, 0xa2, 0x56, 0xda, 0x1e, 0xab, 0x8d, 0x3d,
0x8d, 0x57, 0xe0, 0xdd, 0x95, 0xbd, 0xa0, 0xf0, 0x8f, 0xfa, 0x63, 0x7a, 0xec, 0x0f, 0xaa, 0x76,
0x0d, 0xa6, 0x84, 0xa8, 0xb7, 0x7d, 0xf3, 0xde, 0xcc, 0x8e, 0xdf, 0x5b, 0x43, 0x57, 0xd2, 0xc3,
0x66, 0xcd, 0x47, 0x3a, 0x57, 0x46, 0x61, 0xb3, 0x44, 0xe1, 0xaf, 0x1a, 0xc0, 0xc2, 0x1d, 0xe7,
0x64, 0x38, 0x8e, 0x81, 0xdd, 0xef, 0x34, 0x0d, 0xbc, 0x5b, 0x6f, 0xe8, 0x8f, 0xdf, 0x8c, 0xf6,
0x3d, 0x47, 0xc5, 0x68, 0x4e, 0x45, 0xc1, 0x1f, 0xc9, 0xaa, 0x22, 0x66, 0x76, 0x9a, 0xf0, 0x0e,
0x5a, 0x9f, 0xc8, 0x70, 0xb1, 0x2e, 0x06, 0xb5, 0x5b, 0x6f, 0xd8, 0x19, 0xdf, 0x9c, 0xb7, 0xed,
0x05, 0x51, 0x2b, 0x29, 0x0f, 0xe1, 0x6f, 0x0f, 0x3a, 0x7f, 0x8d, 0xc2, 0x4b, 0x60, 0x0b, 0x25,
0x29, 0xb8, 0xc0, 0x1e, 0xb4, 0xa7, 0xaa, 0x30, 0x9f, 0x37, 0x94, 0xef, 0x02, 0x0f, 0x11, 0xfc,
0x0a, 0x46, 0xa4, 0xd7, 0xbb, 0xa0, 0x86, 0xaf, 0xa1, 0x6f, 0x6b, 0x5f, 0x74, 0xc2, 0x0d, 0x2d,
0x94, 0x11, 0x3f, 0x44, 0xcc, 0x8d, 0x50, 0x32, 0xa8, 0xe3, 0x0d, 0xbc, 0xb2, 0xdc, 0x5c, 0x6d,
0x29, 0x39, 0xa1, 0xd8, 0x81, 0x5a, 0x6e, 0x64, 0x9c, 0x9e, 0x50, 0x0d, 0xf4, 0x01, 0x2c, 0xf5,
0x2d, 0x55, 0x3c, 0x13, 0x41, 0x13, 0xff, 0x83, 0xab, 0x23, 0x2e, 0xaf, 0x6d, 0xd9, 0xcd, 0x96,
0xdc, 0xa4, 0x93, 0x94, 0xe2, 0x55, 0x70, 0x69, 0x37, 0xab, 0x60, 0x29, 0x69, 0x87, 0x5b, 0xb8,
0x3e, 0xfb, 0x68, 0xfc, 0x1f, 0x1a, 0x5f, 0xb5, 0x9c, 0x69, 0xe7, 0x6a, 0x2f, 0x6a, 0x6c, 0x2d,
0xc0, 0xf7, 0x00, 0x33, 0xfd, 0x51, 0x26, 0x4b, 0x95, 0x1b, 0xeb, 0x5c, 0x7d, 0xd8, 0x19, 0x5f,
0x1f, 0x9c, 0xab, 0x98, 0x08, 0x44, 0x25, 0xc2, 0x01, 0xb4, 0x62, 0xb5, 0x91, 0x86, 0xf2, 0x41,
0xdd, 0x8d, 0x3a, 0xc0, 0xf0, 0x1d, 0xb4, 0xab, 0x16, 0xf4, 0xa1, 0x56, 0x5d, 0x56, 0x13, 0x1a,
0x11, 0x98, 0xad, 0xbb, 0x74, 0x7a, 0x11, 0xd3, 0x2a, 0x37, 0xe1, 0xd3, 0x21, 0xf6, 0xa5, 0x90,
0x8f, 0xff, 0x8e, 0xdd, 0x2a, 0x5e, 0x88, 0x1d, 0x81, 0xdd, 0x8b, 0x8c, 0xdc, 0x54, 0x16, 0x31,
0x23, 0x32, 0x0a, 0xc3, 0xb3, 0x50, 0x6d, 0x73, 0x70, 0x81, 0x6d, 0x68, 0x94, 0x16, 0x79, 0xe1,
0x77, 0xb8, 0x2a, 0xe7, 0x4e, 0xb9, 0x4c, 0x8a, 0x94, 0xaf, 0x08, 0x3f, 0x1c, 0x5f, 0x90, 0xe7,
0x5e, 0xd0, 0xb3, 0x0d, 0x2a, 0xe5, 0xf3, 0x67, 0x64, 0x97, 0x98, 0x66, 0x3c, 0x76, 0x4b, 0x74,
0x23, 0x96, 0x66, 0x3c, 0x0e, 0x7f, 0x7a, 0xd0, 0x7f, 0xb9, 0xcf, 0xca, 0x27, 0x94, 0x1b, 0x77,
0x4b, 0x37, 0x62, 0x31, 0xe5, 0x06, 0xdf, 0x82, 0x3f, 0x93, 0xc2, 0x08, 0x6e, 0x54, 0x3e, 0x93,
0x09, 0x3d, 0xed, 0x7d, 0xf2, 0xc5, 0x49, 0xd5, 0xea, 0x22, 0x2a, 0xb4, 0x92, 0x09, 0xed, 0x75,
0x65, 0x06, 0x7e, 0x7e, 0x52, 0xc5, 0x3e, 0x34, 0x27, 0x4a, 0xad, 0x04, 0x0d, 0x98, 0x73, 0xa6,
0x19, 0x3b, 0x54, 0xf9, 0xd5, 0x38, 0xfa, 0xf5, 0xd0, 0x74, 0x3f, 0xe3, 0xdd, 0x9f, 0x00, 0x00,
0x00, 0xff, 0xff, 0x65, 0xc6, 0x25, 0x44, 0x9c, 0x03, 0x00, 0x00,
}

59
nebula.proto Normal file
View File

@ -0,0 +1,59 @@
syntax = "proto3";
package nebula;
message NebulaMeta {
enum MessageType {
None = 0;
HostQuery = 1;
HostQueryReply = 2;
HostUpdateNotification = 3;
HostMovedNotification = 4;
HostPunchNotification = 5;
HostWhoami = 6;
HostWhoamiReply = 7;
PathCheck = 8;
PathCheckReply = 9;
}
MessageType Type = 1;
NebulaMetaDetails Details = 2;
}
message NebulaMetaDetails {
uint32 VpnIp = 1;
repeated IpAndPort IpAndPorts = 2;
uint32 counter = 3;
}
message IpAndPort {
uint32 Ip = 1;
uint32 Port = 2;
}
message NebulaPing {
enum MessageType {
Ping = 0;
Reply = 1;
}
MessageType Type = 1;
uint64 Time = 2;
}
message NebulaHandshake {
NebulaHandshakeDetails Details = 1;
bytes Hmac = 2;
}
message NebulaHandshakeDetails {
bytes Cert = 1;
uint32 InitiatorIndex = 2;
uint32 ResponderIndex = 3;
uint64 Cookie = 4;
uint64 Time = 5;
}

60
noise.go Normal file
View File

@ -0,0 +1,60 @@
package nebula
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/flynn/noise"
)
type endiannes interface {
PutUint64(b []byte, v uint64)
}
var noiseEndiannes endiannes = binary.BigEndian
type NebulaCipherState struct {
c noise.Cipher
//k [32]byte
//n uint64
}
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
return &NebulaCipherState{c: s.Cipher()}
}
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
if s != nil {
// TODO: Is this okay now that we have made messageCounter atomic?
// Alternative may be to split the counter space into ranges
//if n <= s.n {
// return nil, errors.New("CRITICAL: a duplicate counter value was used")
//}
//s.n = n
nb[0] = 0
nb[1] = 0
nb[2] = 0
nb[3] = 0
noiseEndiannes.PutUint64(nb[4:], n)
out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad)
//l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext))
return out, nil
} else {
return nil, errors.New("no cipher state available to encrypt")
}
}
func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) {
if s != nil {
nb[0] = 0
nb[1] = 0
nb[2] = 0
nb[3] = 0
noiseEndiannes.PutUint64(nb[4:], n)
return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad)
} else {
return []byte{}, nil
}
}

410
outside.go Normal file
View File

@ -0,0 +1,410 @@
package nebula
import (
"encoding/binary"
"github.com/flynn/noise"
"github.com/golang/protobuf/proto"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
// "github.com/google/gopacket"
// "github.com/google/gopacket/layers"
// "encoding/binary"
"errors"
"fmt"
"time"
"golang.org/x/net/ipv4"
)
const (
minFwPacketLen = 4
)
func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, nb []byte) {
err := header.Parse(packet)
if err != nil {
// TODO: best if we return this and let caller log
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 {
l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
}
return
}
//l.Error("in packet ", header, packet[HeaderLen:])
// verify if we've seen this index before, otherwise respond to the handshake initiation
hostinfo, err := f.hostMap.QueryIndex(header.RemoteIndex)
var ci *ConnectionState
if err == nil {
ci = hostinfo.ConnectionState
}
switch header.Type {
case message:
if !f.handleEncrypted(ci, addr, header) {
return
}
f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb)
// Fallthrough to the bottom to record incoming traffic
case lightHouse:
if !f.handleEncrypted(ci, addr, header) {
return
}
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
l.WithError(err).WithField("udpAddr", addr).WithField("vpnIp", IntIp(hostinfo.hostId)).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
f.lightHouse.HandleRequest(addr, hostinfo.hostId, d, hostinfo.GetCert(), f)
// Fallthrough to the bottom to record incoming traffic
case test:
if !f.handleEncrypted(ci, addr, header) {
return
}
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
l.WithError(err).WithField("udpAddr", addr).WithField("vpnIp", IntIp(hostinfo.hostId)).
WithField("packet", packet).
Error("Failed to decrypt test packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
if header.Subtype == testRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, addr)
f.send(test, testReply, ci, hostinfo, hostinfo.remote, d, nb, out)
}
// Fallthrough to the bottom to record incoming traffic
// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
// are unauthenticated
case handshake:
HandleIncomingHandshake(f, addr, packet, header, hostinfo)
return
case recvError:
// TODO: Remove this with recv_error deprecation
f.handleRecvError(addr, header)
return
case closeTunnel:
if !f.handleEncrypted(ci, addr, header) {
return
}
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
default:
l.Debugf("Unexpected packet received from %s", addr)
return
}
f.handleHostRoaming(hostinfo, addr)
f.connectionManager.In(hostinfo.hostId)
}
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
f.connectionManager.ClearIP(hostInfo.hostId)
f.connectionManager.ClearPendingDeletion(hostInfo.hostId)
f.lightHouse.DeleteVpnIP(hostInfo.hostId)
f.hostMap.DeleteVpnIP(hostInfo.hostId)
f.hostMap.DeleteIndex(hostInfo.localIndexId)
}
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
if hostDidRoam(hostinfo.remote, addr) {
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSupressSeconds*time.Second {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Debugf("Supressing roam back to previous remote for %d seconds", RoamingSupressSeconds)
}
return
}
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
remoteCopy := *hostinfo.remote
hostinfo.lastRoamRemote = &remoteCopy
hostinfo.SetRemote(*addr)
if f.lightHouse.amLighthouse {
f.lightHouse.AddRemote(hostinfo.hostId, addr, false)
}
}
}
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
// If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(header.MessageCounter) {
f.sendRecvError(addr, header.RemoteIndex)
return false
}
return true
}
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
}
// Is it an ipv4 packet?
if int((data[0]>>4)&0x0f) != 4 {
return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
}
// Adjust our start position based on the advertised ip header length
ihl := int(data[0]&0x0f) << 2
// Well formed ip header length?
if ihl < ipv4.HeaderLen {
return fmt.Errorf("packet had an invalid header length: %v", ihl)
}
// Check if this is the second or further fragment of a fragmented packet.
flagsfrags := binary.BigEndian.Uint16(data[6:8])
fp.Fragment = (flagsfrags & 0x1FFF) != 0
// Firewall handles protocol checks
fp.Protocol = data[9]
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
minLen := ihl
if !fp.Fragment && fp.Protocol != fwProtoICMP {
minLen += minFwPacketLen
}
if len(data) < minLen {
return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
}
// Firewall packets are locally oriented
if incoming {
fp.RemoteIP = binary.BigEndian.Uint32(data[12:16])
fp.LocalIP = binary.BigEndian.Uint32(data[16:20])
if fp.Fragment || fp.Protocol == fwProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
} else {
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
} else {
fp.LocalIP = binary.BigEndian.Uint32(data[12:16])
fp.RemoteIP = binary.BigEndian.Uint32(data[16:20])
if fp.Fragment || fp.Protocol == fwProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
}
return nil
}
func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, header *Header, nb []byte) ([]byte, error) {
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], mc, nb)
if err != nil {
return nil, err
}
if !hostinfo.ConnectionState.window.Update(mc) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("header", header).
Debugln("dropping out of window packet")
return nil, errors.New("out of window packet")
}
return out, nil
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
var err error
// TODO: This breaks subnet routing and needs to also check range of ip subnet
/*
if len(res) > 16 && binary.BigEndian.Uint32(res[12:16]) != ip2int(ci.peerCert.Details.Ips[0].IP) {
l.Debugf("Host %s tried to spoof packet as %s.", ci.peerCert.Details.Ips[0].IP, IntIp(binary.BigEndian.Uint32(res[12:16])))
}
*/
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).Error("Failed to decrypt packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
return
}
err = newPacket(out, true, fwPacket)
if err != nil {
l.WithError(err).WithField("packet", out).WithField("hostInfo", IntIp(hostinfo.hostId)).
Warnf("Error while validating inbound packet")
return
}
if !hostinfo.ConnectionState.window.Update(messageCounter) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet")
return
}
if f.firewall.Drop(out, *fwPacket, true, hostinfo.ConnectionState.peerCert, trustedCAs) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
Debugln("dropping inbound packet")
return
}
f.connectionManager.In(hostinfo.hostId)
err = f.inside.WriteRaw(out)
if err != nil {
l.WithError(err).Error("Failed to write to tun")
}
}
func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
f.metricTxRecvError.Inc(1)
//TODO: this should be a signed message so we can trust that we should drop the index
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
f.outside.WriteTo(b, endpoint)
if l.Level >= logrus.DebugLevel {
l.WithField("index", index).
WithField("udpAddr", endpoint).
Debug("Recv error sent")
}
}
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
f.metricRxRecvError.Inc(1)
// This flag is to stop caring about recv_error from old versions
// This should go away when the old version is gone from prod
if l.Level >= logrus.DebugLevel {
l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).
Debug("Recv error received")
}
hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if err != nil {
l.Debugln(err, ": ", h.RemoteIndex)
return
}
if !hostinfo.RecvErrorExceeded() {
return
}
if hostinfo.remote != nil && hostinfo.remote.String() != addr.String() {
l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return
}
id := hostinfo.localIndexId
host := hostinfo.hostId
// We delete this host from the main hostmap
f.hostMap.DeleteIndex(id)
f.hostMap.DeleteVpnIP(host)
// We also delete it from pending to allow for
// fast reconnect. We must null the connectionstate
// or a counter reuse may happen
hostinfo.ConnectionState = nil
f.handshakeManager.DeleteIndex(id)
f.handshakeManager.DeleteVpnIP(host)
}
/*
func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) {
if ci.eKey != nil {
//TODO: log error?
return
}
msg, err := proto.Marshal(meta)
if err != nil {
l.Debugln("failed to encode header")
}
c := ci.messageCounter
b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c)
ci.messageCounter++
msg := ci.eKey.EncryptDanger(b, nil, msg, c)
//msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c)
f.outside.WriteTo(msg, endpoint)
}
*/
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*cert.NebulaCertificate, error) {
pk := h.PeerStatic()
if pk == nil {
return nil, errors.New("no peer static key was present")
}
if rawCertBytes == nil {
return nil, errors.New("provided payload was empty")
}
r := &cert.RawNebulaCertificate{}
err := proto.Unmarshal(rawCertBytes, r)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %s", err)
}
// If the Details are nil, just exit to avoid crashing
if r.Details == nil {
return nil, fmt.Errorf("certificate did not contain any details")
}
r.Details.PublicKey = pk
recombined, err := proto.Marshal(r)
if err != nil {
return nil, fmt.Errorf("error while recombining certificate: %s", err)
}
c, _ := cert.UnmarshalNebulaCertificate(recombined)
isValid, err := c.Verify(time.Now(), trustedCAs)
if err != nil {
return c, fmt.Errorf("certificate validation failed: %s", err)
} else if !isValid {
// This case should never happen but here's to defensive programming!
return c, errors.New("certificate validation failed but did not return an error")
}
return c, nil
}

80
outside_test.go Normal file
View File

@ -0,0 +1,80 @@
package nebula
import (
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
"net"
"testing"
)
func Test_newPacket(t *testing.T) {
p := &FirewallPacket{}
// length fail
err := newPacket([]byte{0, 1}, true, p)
assert.EqualError(t, err, "packet is less than 20 bytes")
// length fail with ip options
h := ipv4.Header{
Version: 1,
Len: 100,
Src: net.IPv4(10, 0, 0, 1),
Dst: net.IPv4(10, 0, 0, 2),
Options: []byte{0, 1, 0, 2},
}
b, _ := h.Marshal()
err = newPacket(b, true, p)
assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
// not an ipv4 packet
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet is not ipv4, type: 0")
// invalid ihl
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet had an invalid header length: 8")
// account for variable ip header length - incoming
h = ipv4.Header{
Version: 1,
Len: 100,
Src: net.IPv4(10, 0, 0, 1),
Dst: net.IPv4(10, 0, 0, 2),
Options: []byte{0, 1, 0, 2},
Protocol: fwProtoTCP,
}
b, _ = h.Marshal()
b = append(b, []byte{0, 3, 0, 4}...)
err = newPacket(b, true, p)
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(fwProtoTCP))
assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 2)))
assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 1)))
assert.Equal(t, p.RemotePort, uint16(3))
assert.Equal(t, p.LocalPort, uint16(4))
// account for variable ip header length - outgoing
h = ipv4.Header{
Version: 1,
Protocol: 2,
Len: 100,
Src: net.IPv4(10, 0, 0, 1),
Dst: net.IPv4(10, 0, 0, 2),
Options: []byte{0, 1, 0, 2},
}
b, _ = h.Marshal()
b = append(b, []byte{0, 5, 0, 6}...)
err = newPacket(b, false, p)
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(2))
assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 1)))
assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 2)))
assert.Equal(t, p.RemotePort, uint16(6))
assert.Equal(t, p.LocalPort, uint16(5))
}

727
ssh.go Normal file
View File

@ -0,0 +1,727 @@
package nebula
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"github.com/sirupsen/logrus"
"io/ioutil"
"net"
"os"
"reflect"
"runtime/pprof"
"github.com/slackhq/nebula/sshd"
"strings"
"syscall"
)
type sshListHostMapFlags struct {
Json bool
Pretty bool
}
type sshPrintCertFlags struct {
Json bool
Pretty bool
}
type sshPrintTunnelFlags struct {
Pretty bool
}
type sshChangeRemoteFlags struct {
Address string
}
type sshCloseTunnelFlags struct {
LocalOnly bool
}
type sshCreateTunnelFlags struct {
Address string
}
func wireSSHReload(ssh *sshd.SSHServer, c *Config) {
c.RegisterReloadCallback(func(c *Config) {
if c.GetBool("sshd.enabled", false) {
err := configSSH(ssh, c)
if err != nil {
l.WithError(err).Error("Failed to reconfigure the sshd")
ssh.Stop()
}
} else {
ssh.Stop()
}
})
}
func configSSH(ssh *sshd.SSHServer, c *Config) error {
//TODO conntrack list
//TODO print firewall rules or hash?
listen := c.GetString("sshd.listen", "")
if listen == "" {
return fmt.Errorf("sshd.listen must be provided")
}
port := strings.Split(listen, ":")
if len(port) < 2 {
return fmt.Errorf("sshd.listen does not have a port")
} else if port[1] == "22" {
return fmt.Errorf("sshd.listen can not use port 22")
}
//TODO: no good way to reload this right now
hostKeyFile := c.GetString("sshd.host_key", "")
if hostKeyFile == "" {
return fmt.Errorf("sshd.host_key must be provided")
}
hostKeyBytes, err := ioutil.ReadFile(hostKeyFile)
if err != nil {
return fmt.Errorf("error while loading sshd.host_key file: %s", err)
}
err = ssh.SetHostKey(hostKeyBytes)
if err != nil {
return fmt.Errorf("error while adding sshd.host_key: %s", err)
}
rawKeys := c.Get("sshd.authorized_users")
keys, ok := rawKeys.([]interface{})
if ok {
for _, rk := range keys {
kDef, ok := rk.(map[interface{}]interface{})
if !ok {
l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
continue
}
user, ok := kDef["user"].(string)
if !ok {
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field")
continue
}
k := kDef["keys"]
switch v := k.(type) {
case string:
err := ssh.AddAuthorizedKey(user, v)
if err != nil {
l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key")
continue
}
case []interface{}:
for _, subK := range v {
sk, ok := subK.(string)
if !ok {
l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key")
continue
}
err := ssh.AddAuthorizedKey(user, sk)
if err != nil {
l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key")
continue
}
}
default:
l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood")
}
}
} else {
l.Info("no ssh users to authorize")
}
if c.GetBool("sshd.enabled", false) {
ssh.Stop()
go ssh.Run(listen)
} else {
ssh.Stop()
}
return nil
}
func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
ssh.RegisterCommand(&sshd.Command{
Name: "list-hostmap",
ShortDescription: "List all known previously connected hosts",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshListHostMapFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListHostMap(hostMap, fs, w)
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "list-pending-hostmap",
ShortDescription: "List all handshaking hosts",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshListHostMapFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListHostMap(pendingHostMap, fs, w)
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "list-lighthouse-addrmap",
ShortDescription: "List all lighthouse map entries",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshListHostMapFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshListLighthouseMap(lightHouse, fs, w)
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "reload",
ShortDescription: "Reloads configuration from disk, same as sending HUP to the process",
Callback: sshReload,
})
ssh.RegisterCommand(&sshd.Command{
Name: "start-cpu-profile",
ShortDescription: "Starts a cpu profile and write output to the provided file",
Callback: sshStartCpuProfile,
})
ssh.RegisterCommand(&sshd.Command{
Name: "stop-cpu-profile",
ShortDescription: "Stops a cpu profile and writes output to the previously provided file",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
pprof.StopCPUProfile()
return w.WriteLine("If a CPU profile was running it is now stopped")
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "save-heap-profile",
ShortDescription: "Saves a heap profile to the provided path",
Callback: sshGetHeapProfile,
})
ssh.RegisterCommand(&sshd.Command{
Name: "log-level",
ShortDescription: "Gets or sets the current log level",
Callback: sshLogLevel,
})
ssh.RegisterCommand(&sshd.Command{
Name: "log-format",
ShortDescription: "Gets or sets the current log format",
Callback: sshLogFormat,
})
ssh.RegisterCommand(&sshd.Command{
Name: "version",
ShortDescription: "Prints the currently running version of nebula",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshVersion(ifce, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "print-cert",
ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintCertFlags{}
fl.BoolVar(&s.Json, "json", false, "outputs as json")
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshPrintCert(ifce, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "print-tunnel",
ShortDescription: "Prints json details about a tunnel for the provided vpn ip",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintTunnelFlags{}
fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshPrintTunnel(ifce, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "change-remote",
ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshChangeRemoteFlags{}
fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port")
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshChangeRemote(ifce, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "close-tunnel",
ShortDescription: "Closes a tunnel for the provided vpn ip",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshCloseTunnelFlags{}
fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down")
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshCloseTunnel(ifce, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "create-tunnel",
ShortDescription: "Creates a tunnel for the provided vpn ip and address",
Help: "The lighthouses will be queried for real addresses but you can provide one as well.",
Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshCreateTunnelFlags{}
fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ")
return fl, &s
},
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshCreateTunnel(ifce, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "query-lighthouse",
ShortDescription: "Query the lighthouses for the provided vpn ip",
Help: "This command is asynchronous. Only currently known udp ips will be printed.",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshQueryLighthouse(ifce, fs, a, w)
},
})
}
func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error {
fs, ok := a.(*sshListHostMapFlags)
if !ok {
//TODO: error
return nil
}
hostMap.RLock()
defer hostMap.RUnlock()
if fs.Json || fs.Pretty {
js := json.NewEncoder(w.GetWriter())
if fs.Pretty {
js.SetIndent("", " ")
}
d := make([]m, len(hostMap.Hosts))
x := 0
var h m
for _, v := range hostMap.Hosts {
h = m{
"vpnIp": int2ip(v.hostId),
"localIndex": v.localIndexId,
"remoteIndex": v.remoteIndexId,
"remoteAddrs": v.RemoteUDPAddrs(),
"cachedPackets": len(v.packetStore),
"cert": v.GetCert(),
}
if v.ConnectionState != nil {
h["messageCounter"] = v.ConnectionState.messageCounter
}
d[x] = h
x++
}
err := js.Encode(d)
if err != nil {
//TODO
return nil
}
} else {
for i, v := range hostMap.Hosts {
err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(i), v.RemoteUDPAddrs()))
if err != nil {
return err
}
}
}
return nil
}
func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error {
fs, ok := a.(*sshListHostMapFlags)
if !ok {
//TODO: error
return nil
}
lightHouse.RLock()
defer lightHouse.RUnlock()
if fs.Json || fs.Pretty {
js := json.NewEncoder(w.GetWriter())
if fs.Pretty {
js.SetIndent("", " ")
}
d := make([]m, len(lightHouse.addrMap))
x := 0
var h m
for vpnIp, v := range lightHouse.addrMap {
ips := make([]string, len(v))
for i, ip := range v {
ips[i] = ip.String()
}
h = m{
"vpnIp": int2ip(vpnIp),
"addrs": ips,
}
d[x] = h
x++
}
err := js.Encode(d)
if err != nil {
//TODO
return nil
}
} else {
for vpnIp, v := range lightHouse.addrMap {
ips := make([]string, len(v))
for i, ip := range v {
ips[i] = ip.String()
}
err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(vpnIp), ips))
if err != nil {
return err
}
}
}
return nil
}
func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
err := w.WriteLine("No path to write profile provided")
return err
}
file, err := os.Create(a[0])
if err != nil {
err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
return err
}
err = pprof.StartCPUProfile(file)
if err != nil {
err = w.WriteLine(fmt.Sprintf("Unable to start cpu profile: %s", err))
return err
}
err = w.WriteLine(fmt.Sprintf("Started cpu profile, issue stop-cpu-profile to write the output to %s", a))
return err
}
func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("%s", ifce.version))
}
func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine("No vpn ip was provided")
}
vpnIp := ip2int(net.ParseIP(a[0]))
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
ips, _ := ifce.lightHouse.Query(vpnIp, ifce)
return json.NewEncoder(w.GetWriter()).Encode(ips)
}
func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshCloseTunnelFlags)
if !ok {
//TODO: error
return nil
}
if len(a) == 0 {
return w.WriteLine("No vpn ip was provided")
}
vpnIp := ip2int(net.ParseIP(a[0]))
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
if !flags.LocalOnly {
ifce.send(
closeTunnel,
0,
hostInfo.ConnectionState,
hostInfo,
hostInfo.remote,
[]byte{},
make([]byte, 12, 12),
make([]byte, mtu),
)
}
ifce.closeTunnel(hostInfo)
return w.WriteLine("Closed")
}
func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshCreateTunnelFlags)
if !ok {
//TODO: error
return nil
}
if len(a) == 0 {
return w.WriteLine("No vpn ip was provided")
}
vpnIp := ip2int(net.ParseIP(a[0]))
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, _ := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
}
hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIP(uint32(vpnIp))
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
}
var addr *udpAddr
if flags.Address != "" {
addr = NewUDPAddrFromString(flags.Address)
if addr == nil {
return w.WriteLine("Address could not be parsed")
}
}
hostInfo = ifce.handshakeManager.AddVpnIP(vpnIp)
if addr != nil {
hostInfo.SetRemote(*addr)
}
ifce.getOrHandshake(vpnIp)
return w.WriteLine("Created")
}
func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshChangeRemoteFlags)
if !ok {
//TODO: error
return nil
}
if len(a) == 0 {
return w.WriteLine("No vpn ip was provided")
}
if flags.Address == "" {
return w.WriteLine("No address was provided")
}
addr := NewUDPAddrFromString(flags.Address)
if addr == nil {
return w.WriteLine("Address could not be parsed")
}
vpnIp := ip2int(net.ParseIP(a[0]))
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
hostInfo.SetRemote(*addr)
return w.WriteLine("Changed")
}
func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine("No path to write profile provided")
}
file, err := os.Create(a[0])
if err != nil {
err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
return err
}
err = pprof.WriteHeapProfile(file)
if err != nil {
err = w.WriteLine(fmt.Sprintf("Unable to write profile: %s", err))
return err
}
err = w.WriteLine(fmt.Sprintf("Mem profile created at %s", a))
return err
}
func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
}
level, err := logrus.ParseLevel(a[0])
if err != nil {
return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels))
}
l.SetLevel(level)
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
}
func sshLogFormat(fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
}
logFormat := strings.ToLower(a[0])
switch logFormat {
case "text":
l.Formatter = &logrus.TextFormatter{}
case "json":
l.Formatter = &logrus.JSONFormatter{}
default:
return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
}
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
}
func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintCertFlags)
if !ok {
//TODO: error
return nil
}
cert := ifce.certState.certificate
if len(a) > 0 {
vpnIp := ip2int(net.ParseIP(a[0]))
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
cert = hostInfo.GetCert()
}
if args.Json || args.Pretty {
b, err := cert.MarshalJSON()
if err != nil {
//TODO: handle it
return nil
}
if args.Pretty {
buf := new(bytes.Buffer)
err := json.Indent(buf, b, "", " ")
b = buf.Bytes()
if err != nil {
//TODO: handle it
return nil
}
}
return w.WriteBytes(b)
}
return w.WriteLine(cert.String())
}
func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintTunnelFlags)
if !ok {
//TODO: error
return nil
}
if len(a) == 0 {
return w.WriteLine("No vpn ip was provided")
}
vpnIp := ip2int(net.ParseIP(a[0]))
if vpnIp == 0 {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
if err != nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
enc := json.NewEncoder(w.GetWriter())
if args.Pretty {
enc.SetIndent("", " ")
}
return enc.Encode(hostInfo)
}
func sshReload(fs interface{}, a []string, w sshd.StringWriter) error {
p, err := os.FindProcess(os.Getpid())
if err != nil {
return w.WriteLine(err.Error())
//TODO
}
err = p.Signal(syscall.SIGHUP)
if err != nil {
return w.WriteLine(err.Error())
//TODO
}
return w.WriteLine("HUP sent")
}

161
sshd/command.go Normal file
View File

@ -0,0 +1,161 @@
package sshd
import (
"errors"
"flag"
"fmt"
"github.com/armon/go-radix"
"sort"
"strings"
)
// CommandFlags is a function called before help or command execution to parse command line flags
// It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags
type CommandFlags func() (*flag.FlagSet, interface{})
// CommandCallback is the function called when your command should execute.
// fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved
// and handled automatically for you.
// a will be any unconsumed arguments, if no Command.Flags was available this will be all the flags passed in.
// w is the writer to use when sending messages back to the client.
// If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user
// where appropriate
type CommandCallback func(fs interface{}, a []string, w StringWriter) error
type Command struct {
Name string
ShortDescription string
Help string
Flags CommandFlags
Callback CommandCallback
}
func execCommand(c *Command, args []string, w StringWriter) error {
var (
fl *flag.FlagSet
fs interface{}
)
if c.Flags != nil {
fl, fs = c.Flags()
if fl != nil {
//TODO: handle the error
fl.Parse(args)
args = fl.Args()
}
}
return c.Callback(fs, args, w)
}
func dumpCommands(c *radix.Tree, w StringWriter) {
err := w.WriteLine("Available commands:")
if err != nil {
//TODO: log
return
}
cmds := make([]string, 0)
for _, l := range allCommands(c) {
cmds = append(cmds, fmt.Sprintf("%s - %s", l.Name, l.ShortDescription))
}
sort.Strings(cmds)
err = w.Write(strings.Join(cmds, "\n") + "\n\n")
if err != nil {
//TODO: log
}
}
func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) {
cmd, ok := c.Get(sCmd)
if !ok {
return nil, nil
}
command, ok := cmd.(*Command)
if !ok {
return nil, errors.New("failed to cast command")
}
return command, nil
}
func matchCommand(c *radix.Tree, cmd string) []string {
cmds := make([]string, 0)
c.WalkPrefix(cmd, func(found string, v interface{}) bool {
cmds = append(cmds, found)
return false
})
sort.Strings(cmds)
return cmds
}
func allCommands(c *radix.Tree) []*Command {
cmds := make([]*Command, 0)
c.WalkPrefix("", func(found string, v interface{}) bool {
cmd, ok := v.(*Command)
if ok {
cmds = append(cmds, cmd)
}
return false
})
return cmds
}
func helpCallback(commands *radix.Tree, a []string, w StringWriter) (err error) {
// Just typed help
if len(a) == 0 {
dumpCommands(commands, w)
return nil
}
// We are printing a specific commands help text
cmd, err := lookupCommand(commands, a[0])
if err != nil {
//TODO: handle error
//TODO: message the user
return
}
if cmd != nil {
err = w.WriteLine(fmt.Sprintf("%s - %s", cmd.Name, cmd.ShortDescription))
if err != nil {
return err
}
if cmd.Help != "" {
err = w.WriteLine(fmt.Sprintf(" %s", cmd.Help))
if err != nil {
return err
}
}
if cmd.Flags != nil {
fs, _ := cmd.Flags()
if fs != nil {
fs.SetOutput(w.GetWriter())
fs.PrintDefaults()
}
}
return nil
}
err = w.WriteLine("Command not available " + a[0])
if err != nil {
return err
}
return nil
}
func checkHelpArgs(args []string) bool {
for _, a := range args {
if a == "-h" || a == "-help" {
return true
}
}
return false
}

182
sshd/server.go Normal file
View File

@ -0,0 +1,182 @@
package sshd
import (
"fmt"
"github.com/armon/go-radix"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"net"
)
type SSHServer struct {
config *ssh.ServerConfig
l *logrus.Entry
// Map of user -> authorized keys
trustedKeys map[string]map[string]bool
// List of available commands
helpCommand *Command
commands *radix.Tree
listener net.Listener
conns map[int]*session
counter int
}
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
s := &SSHServer{
trustedKeys: make(map[string]map[string]bool),
l: l,
commands: radix.New(),
conns: make(map[int]*session),
}
s.config = &ssh.ServerConfig{
PublicKeyCallback: s.matchPubKey,
//TODO: AuthLogCallback: s.authAttempt,
//TODO: version string
ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"),
}
s.RegisterCommand(&Command{
Name: "help",
ShortDescription: "prints available commands or help <command> for specific usage info",
Callback: func(a interface{}, args []string, w StringWriter) error {
return helpCallback(s.commands, args, w)
},
})
return s, nil
}
func (s *SSHServer) SetHostKey(hostPrivateKey []byte) error {
private, err := ssh.ParsePrivateKey(hostPrivateKey)
if err != nil {
return fmt.Errorf("failed to parse private key: %s", err)
}
s.config.AddHostKey(private)
return nil
}
func (s *SSHServer) ClearAuthorizedKeys() {
s.trustedKeys = make(map[string]map[string]bool)
}
// AddAuthorizedKey adds an ssh public key for a user
func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey))
if err != nil {
return err
}
tk, ok := s.trustedKeys[user]
if !ok {
tk = make(map[string]bool)
s.trustedKeys[user] = tk
}
tk[string(pk.Marshal())] = true
s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key")
return nil
}
// RegisterCommand adds a command that can be run by a user, by default only `help` is available
func (s *SSHServer) RegisterCommand(c *Command) {
s.commands.Insert(c.Name, c)
}
// Run begins listening and accepting connections
func (s *SSHServer) Run(addr string) error {
var err error
s.listener, err = net.Listen("tcp", addr)
if err != nil {
return err
}
s.l.WithField("sshListener", addr).Info("SSH server is listening")
for {
c, err := s.listener.Accept()
if err != nil {
s.l.WithError(err).Warn("Error in listener, shutting down")
return nil
}
conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
fp := ""
if conn != nil {
fp = conn.Permissions.Extensions["fp"]
}
if err != nil {
l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr())
if conn != nil {
l = l.WithField("sshUser", conn.User())
conn.Close()
}
if fp != "" {
l = l.WithField("sshFingerprint", fp)
}
l.Warn("failed to handshake")
continue
}
l := s.l.WithField("sshUser", conn.User())
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
s.counter++
counter := s.counter
s.conns[counter] = session
go ssh.DiscardRequests(reqs)
go func() {
<-session.exitChan
s.l.WithField("id", counter).Debug("closing conn")
delete(s.conns, counter)
}()
}
}
func (s *SSHServer) Stop() {
for _, c := range s.conns {
c.Close()
}
if s.listener == nil {
return
}
err := s.listener.Close()
if err != nil {
s.l.WithError(err).Warn("Failed to close the sshd listener")
return
}
s.l.Info("SSH server stopped listening")
return
}
func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
pk := string(pubKey.Marshal())
fp := ssh.FingerprintSHA256(pubKey)
tk, ok := s.trustedKeys[c.User()]
if !ok {
return nil, fmt.Errorf("unknown user %s", c.User())
}
_, ok = tk[pk]
if !ok {
return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp)
}
return &ssh.Permissions{
// Record the public key used for authentication.
Extensions: map[string]string{
"fp": fp,
"user": c.User(),
},
}, nil
}

182
sshd/session.go Normal file
View File

@ -0,0 +1,182 @@
package sshd
import (
"fmt"
"github.com/anmitsu/go-shlex"
"github.com/armon/go-radix"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
"sort"
"strings"
)
type session struct {
l *logrus.Entry
c *ssh.ServerConn
term *terminal.Terminal
commands *radix.Tree
exitChan chan bool
}
func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session {
s := &session{
commands: radix.NewFromMap(commands.ToMap()),
l: l,
c: conn,
exitChan: make(chan bool),
}
s.commands.Insert("logout", &Command{
Name: "logout",
ShortDescription: "Ends the current session",
Callback: func(a interface{}, args []string, w StringWriter) error {
s.Close()
return nil
},
})
go s.handleChannels(chans)
return s
}
func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
for newChannel := range chans {
if newChannel.ChannelType() != "session" {
s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type")
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
s.l.WithError(err).Warn("could not accept channel")
continue
}
go s.handleRequests(requests, channel)
}
}
func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
for req := range in {
var err error
//TODO: maybe support window sizing?
switch req.Type {
case "shell":
if s.term == nil {
s.term = s.createTerm(channel)
err = req.Reply(true, nil)
} else {
err = req.Reply(false, nil)
}
case "pty-req":
err = req.Reply(true, nil)
case "window-change":
err = req.Reply(true, nil)
case "exec":
var payload = struct{ Value string }{}
cErr := ssh.Unmarshal(req.Payload, &payload)
if cErr == nil {
s.dispatchCommand(payload.Value, &stringWriter{channel})
} else {
//TODO: log it
}
channel.Close()
return
default:
s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request")
err = req.Reply(false, nil)
}
if err != nil {
s.l.WithError(err).Info("Error handling ssh session requests")
s.Close()
return
}
}
}
func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal {
//TODO: PS1 with nebula cert name
term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ")
term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
// key 9 is tab
if key == 9 {
cmds := matchCommand(s.commands, line)
if len(cmds) == 1 {
return cmds[0] + " ", len(cmds[0]) + 1, true
}
sort.Strings(cmds)
term.Write([]byte(strings.Join(cmds, "\n") + "\n\n"))
}
return "", 0, false
}
go s.handleInput(channel)
return term
}
func (s *session) handleInput(channel ssh.Channel) {
defer s.Close()
w := &stringWriter{w: s.term}
for {
line, err := s.term.ReadLine()
if err != nil {
//TODO: log
break
}
s.dispatchCommand(line, w)
}
}
func (s *session) dispatchCommand(line string, w StringWriter) {
args, err := shlex.Split(line, true)
if err != nil {
//todo: LOG IT
return
}
if len(args) == 0 {
dumpCommands(s.commands, w)
return
}
c, err := lookupCommand(s.commands, args[0])
if err != nil {
//TODO: handle the error
return
}
if c == nil {
err := w.WriteLine(fmt.Sprintf("did not understand: %s", line))
//TODO: log error
_ = err
dumpCommands(s.commands, w)
return
}
if checkHelpArgs(args) {
s.dispatchCommand(fmt.Sprintf("%s %s", "help", c.Name), w)
return
}
err = execCommand(c, args[1:], w)
if err != nil {
//TODO: log the error
}
return
}
func (s *session) Close() {
s.c.Close()
s.exitChan <- true
}

32
sshd/writer.go Normal file
View File

@ -0,0 +1,32 @@
package sshd
import "io"
type StringWriter interface {
WriteLine(string) error
Write(string) error
WriteBytes([]byte) error
GetWriter() io.Writer
}
type stringWriter struct {
w io.Writer
}
func (w *stringWriter) WriteLine(s string) error {
return w.Write(s + "\n")
}
func (w *stringWriter) Write(s string) error {
_, err := w.w.Write([]byte(s))
return err
}
func (w *stringWriter) WriteBytes(b []byte) error {
_, err := w.w.Write(b)
return err
}
func (w *stringWriter) GetWriter() io.Writer {
return w.w
}

89
stats.go Normal file
View File

@ -0,0 +1,89 @@
package nebula
import (
"errors"
"fmt"
"github.com/cyberdelia/go-metrics-graphite"
mp "github.com/nbrownus/go-metrics-prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics"
"log"
"net"
"net/http"
"time"
)
func startStats(c *Config) error {
mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" {
return nil
}
interval := c.GetDuration("stats.interval", 0)
if interval == 0 {
return fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", ""))
}
switch mType {
case "graphite":
startGraphiteStats(interval, c)
case "prometheus":
startPrometheusStats(interval, c)
default:
return fmt.Errorf("stats.type was not understood: %s", mType)
}
metrics.RegisterDebugGCStats(metrics.DefaultRegistry)
metrics.RegisterRuntimeMemStats(metrics.DefaultRegistry)
go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval)
go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval)
return nil
}
func startGraphiteStats(i time.Duration, c *Config) error {
proto := c.GetString("stats.protocol", "tcp")
host := c.GetString("stats.host", "")
if host == "" {
return errors.New("stats.host can not be empty")
}
prefix := c.GetString("stats.prefix", "nebula")
addr, err := net.ResolveTCPAddr(proto, host)
if err != nil {
return fmt.Errorf("error while setting up graphite sink: %s", err)
}
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
return nil
}
func startPrometheusStats(i time.Duration, c *Config) error {
namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "")
listen := c.GetString("stats.listen", "")
if listen == "" {
return fmt.Errorf("stats.listen should not be emtpy")
}
path := c.GetString("stats.path", "")
if path == "" {
return fmt.Errorf("stats.path should not be emtpy")
}
pr := prometheus.NewRegistry()
pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i)
go pClient.UpdatePrometheusMetrics()
go func() {
l.Infof("Prometheus stats listening on %s at %s", listen, path)
http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
log.Fatal(http.ListenAndServe(listen, nil))
}()
return nil
}

191
timeout.go Normal file
View File

@ -0,0 +1,191 @@
package nebula
import (
"time"
)
// How many timer objects should be cached
const timerCacheMax = 50000
var emptyFWPacket = FirewallPacket{}
type TimerWheel struct {
// Current tick
current int
// Cheat on finding the length of the wheel
wheelLen int
// Last time we ticked, since we are lazy ticking
lastTick *time.Time
// Durations of a tick and the entire wheel
tickDuration time.Duration
wheelDuration time.Duration
// The actual wheel which is just a set of singly linked lists, head/tail pointers
wheel []*TimeoutList
// Singly linked list of items that have timed out of the wheel
expired *TimeoutList
// Item cache to avoid garbage collect
itemCache *TimeoutItem
itemsCached int
}
// Represents a tick in the wheel
type TimeoutList struct {
Head *TimeoutItem
Tail *TimeoutItem
}
// Represents an item within a tick
type TimeoutItem struct {
Packet FirewallPacket
Next *TimeoutItem
}
// Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
// Purge must be called once per entry to actually remove anything
func NewTimerWheel(min, max time.Duration) *TimerWheel {
//TODO provide an error
//if min >= max {
// return nil
//}
// Round down and add 1 so we can have the smallest # of ticks in the wheel and still account for a full
// max duration
wLen := int((max / min) + 1)
tw := TimerWheel{
wheelLen: wLen,
wheel: make([]*TimeoutList, wLen),
tickDuration: min,
wheelDuration: max,
expired: &TimeoutList{},
}
for i := range tw.wheel {
tw.wheel[i] = &TimeoutList{}
}
return &tw
}
// Add will add a FirewallPacket to the wheel in it's proper timeout
func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem {
// Check and see if we should progress the tick
tw.advance(time.Now())
i := tw.findWheel(timeout)
// Try to fetch off the cache
ti := tw.itemCache
if ti != nil {
tw.itemCache = ti.Next
tw.itemsCached--
ti.Next = nil
} else {
ti = &TimeoutItem{}
}
// Relink and return
ti.Packet = v
if tw.wheel[i].Tail == nil {
tw.wheel[i].Head = ti
tw.wheel[i].Tail = ti
} else {
tw.wheel[i].Tail.Next = ti
tw.wheel[i].Tail = ti
}
return ti
}
func (tw *TimerWheel) Purge() (FirewallPacket, bool) {
if tw.expired.Head == nil {
return emptyFWPacket, false
}
ti := tw.expired.Head
tw.expired.Head = ti.Next
if tw.expired.Head == nil {
tw.expired.Tail = nil
}
// Clear out the items references
ti.Next = nil
// Maybe cache it for later
if tw.itemsCached < timerCacheMax {
ti.Next = tw.itemCache
tw.itemCache = ti
tw.itemsCached++
}
return ti.Packet, true
}
// advance will move the wheel forward by proper number of ticks. The caller _should_ lock the wheel before calling this
func (tw *TimerWheel) findWheel(timeout time.Duration) (i int) {
if timeout < tw.tickDuration {
// Can't track anything below the set resolution
timeout = tw.tickDuration
} else if timeout > tw.wheelDuration {
// We aren't handling timeouts greater than the wheels duration
timeout = tw.wheelDuration
}
// Find the next highest, rounding up
tick := int(((timeout - 1) / tw.tickDuration) + 1)
// Add another tick since the current tick may almost be over then map it to the wheel from our
// current position
tick += tw.current + 1
if tick >= tw.wheelLen {
tick -= tw.wheelLen
}
return tick
}
// advance will lock and move the wheel forward by proper number of ticks.
func (tw *TimerWheel) advance(now time.Time) {
if tw.lastTick == nil {
tw.lastTick = &now
}
// We want to round down
ticks := int(now.Sub(*tw.lastTick) / tw.tickDuration)
adv := ticks
if ticks > tw.wheelLen {
ticks = tw.wheelLen
}
for i := 0; i < ticks; i++ {
tw.current++
if tw.current >= tw.wheelLen {
tw.current = 0
}
if tw.wheel[tw.current].Head != nil {
// We need to append the expired items as to not starve evicting the oldest ones
if tw.expired.Tail == nil {
tw.expired.Head = tw.wheel[tw.current].Head
tw.expired.Tail = tw.wheel[tw.current].Tail
} else {
tw.expired.Tail.Next = tw.wheel[tw.current].Head
tw.expired.Tail = tw.wheel[tw.current].Tail
}
tw.wheel[tw.current].Head = nil
tw.wheel[tw.current].Tail = nil
}
}
// Advance the tick based on duration to avoid losing some accuracy
newTick := tw.lastTick.Add(tw.tickDuration * time.Duration(adv))
tw.lastTick = &newTick
}

196
timeout_system.go Normal file
View File

@ -0,0 +1,196 @@
package nebula
import (
"sync"
"time"
)
// How many timer objects should be cached
const systemTimerCacheMax = 50000
type SystemTimerWheel struct {
// Current tick
current int
// Cheat on finding the length of the wheel
wheelLen int
// Last time we ticked, since we are lazy ticking
lastTick *time.Time
// Durations of a tick and the entire wheel
tickDuration time.Duration
wheelDuration time.Duration
// The actual wheel which is just a set of singly linked lists, head/tail pointers
wheel []*SystemTimeoutList
// Singly linked list of items that have timed out of the wheel
expired *SystemTimeoutList
// Item cache to avoid garbage collect
itemCache *SystemTimeoutItem
itemsCached int
lock sync.Mutex
}
// Represents a tick in the wheel
type SystemTimeoutList struct {
Head *SystemTimeoutItem
Tail *SystemTimeoutItem
}
// Represents an item within a tick
type SystemTimeoutItem struct {
Item uint32
Next *SystemTimeoutItem
}
// Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
// Purge must be called once per entry to actually remove anything
func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
//TODO provide an error
//if min >= max {
// return nil
//}
// Round down and add 1 so we can have the smallest # of ticks in the wheel and still account for a full
// max duration
wLen := int((max / min) + 1)
tw := SystemTimerWheel{
wheelLen: wLen,
wheel: make([]*SystemTimeoutList, wLen),
tickDuration: min,
wheelDuration: max,
expired: &SystemTimeoutList{},
}
for i := range tw.wheel {
tw.wheel[i] = &SystemTimeoutList{}
}
return &tw
}
func (tw *SystemTimerWheel) Add(v uint32, timeout time.Duration) *SystemTimeoutItem {
tw.lock.Lock()
defer tw.lock.Unlock()
// Check and see if we should progress the tick
//tw.advance(time.Now())
i := tw.findWheel(timeout)
// Try to fetch off the cache
ti := tw.itemCache
if ti != nil {
tw.itemCache = ti.Next
ti.Next = nil
tw.itemsCached--
} else {
ti = &SystemTimeoutItem{}
}
// Relink and return
ti.Item = v
ti.Next = tw.wheel[i].Head
tw.wheel[i].Head = ti
if tw.wheel[i].Tail == nil {
tw.wheel[i].Tail = ti
}
return ti
}
func (tw *SystemTimerWheel) Purge() interface{} {
tw.lock.Lock()
defer tw.lock.Unlock()
if tw.expired.Head == nil {
return nil
}
ti := tw.expired.Head
tw.expired.Head = ti.Next
if tw.expired.Head == nil {
tw.expired.Tail = nil
}
p := ti.Item
// Clear out the items references
ti.Item = 0
ti.Next = nil
// Maybe cache it for later
if tw.itemsCached < systemTimerCacheMax {
ti.Next = tw.itemCache
tw.itemCache = ti
tw.itemsCached++
}
return p
}
func (tw *SystemTimerWheel) findWheel(timeout time.Duration) (i int) {
if timeout < tw.tickDuration {
// Can't track anything below the set resolution
timeout = tw.tickDuration
} else if timeout > tw.wheelDuration {
// We aren't handling timeouts greater than the wheels duration
timeout = tw.wheelDuration
}
// Find the next highest, rounding up
tick := int(((timeout - 1) / tw.tickDuration) + 1)
// Add another tick since the current tick may almost be over then map it to the wheel from our
// current position
tick += tw.current + 1
if tick >= tw.wheelLen {
tick -= tw.wheelLen
}
return tick
}
func (tw *SystemTimerWheel) advance(now time.Time) {
tw.lock.Lock()
defer tw.lock.Unlock()
if tw.lastTick == nil {
tw.lastTick = &now
}
// We want to round down
ticks := int(now.Sub(*tw.lastTick) / tw.tickDuration)
//l.Infoln("Ticks: ", ticks)
for i := 0; i < ticks; i++ {
tw.current++
//l.Infoln("Tick: ", tw.current)
if tw.current >= tw.wheelLen {
tw.current = 0
}
// We need to append the expired items as to not starve evicting the oldest ones
if tw.expired.Tail == nil {
tw.expired.Head = tw.wheel[tw.current].Head
tw.expired.Tail = tw.wheel[tw.current].Tail
} else {
tw.expired.Tail.Next = tw.wheel[tw.current].Head
if tw.wheel[tw.current].Tail != nil {
tw.expired.Tail = tw.wheel[tw.current].Tail
}
}
//l.Infoln("Head: ", tw.expired.Head, "Tail: ", tw.expired.Tail)
tw.wheel[tw.current].Head = nil
tw.wheel[tw.current].Tail = nil
tw.lastTick = &now
}
}

134
timeout_system_test.go Normal file
View File

@ -0,0 +1,134 @@
package nebula
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewSystemTimerWheel(t *testing.T) {
// Make sure we get an object we expect
tw := NewSystemTimerWheel(time.Second, time.Second*10)
assert.Equal(t, 11, tw.wheelLen)
assert.Equal(t, 0, tw.current)
assert.Nil(t, tw.lastTick)
assert.Equal(t, time.Second*1, tw.tickDuration)
assert.Equal(t, time.Second*10, tw.wheelDuration)
assert.Len(t, tw.wheel, 11)
// Assert the math is correct
tw = NewSystemTimerWheel(time.Second*3, time.Second*10)
assert.Equal(t, 4, tw.wheelLen)
tw = NewSystemTimerWheel(time.Second*120, time.Minute*10)
assert.Equal(t, 6, tw.wheelLen)
}
func TestSystemTimerWheel_findWheel(t *testing.T) {
tw := NewSystemTimerWheel(time.Second, time.Second*10)
assert.Len(t, tw.wheel, 11)
// Current + tick + 1 since we don't know how far into current we are
assert.Equal(t, 2, tw.findWheel(time.Second*1))
// Scale up to min duration
assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
// Make sure we hit that last index
assert.Equal(t, 0, tw.findWheel(time.Second*10))
// Scale down to max duration
assert.Equal(t, 0, tw.findWheel(time.Second*11))
tw.current = 1
// Make sure we account for the current position properly
assert.Equal(t, 3, tw.findWheel(time.Second*1))
assert.Equal(t, 1, tw.findWheel(time.Second*10))
}
func TestSystemTimerWheel_Add(t *testing.T) {
tw := NewSystemTimerWheel(time.Second, time.Second*10)
fp1 := ip2int(net.ParseIP("1.2.3.4"))
tw.Add(fp1, time.Second*1)
// Make sure we set head and tail properly
assert.NotNil(t, tw.wheel[2])
assert.Equal(t, fp1, tw.wheel[2].Head.Item)
assert.Nil(t, tw.wheel[2].Head.Next)
assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
assert.Nil(t, tw.wheel[2].Tail.Next)
// Make sure we only modify head
fp2 := ip2int(net.ParseIP("1.2.3.4"))
tw.Add(fp2, time.Second*1)
assert.Equal(t, fp2, tw.wheel[2].Head.Item)
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
assert.Nil(t, tw.wheel[2].Tail.Next)
// Make sure we use free'd items first
tw.itemCache = &SystemTimeoutItem{}
tw.itemsCached = 1
tw.Add(fp2, time.Second*1)
assert.Nil(t, tw.itemCache)
assert.Equal(t, 0, tw.itemsCached)
}
func TestSystemTimerWheel_Purge(t *testing.T) {
// First advance should set the lastTick and do nothing else
tw := NewSystemTimerWheel(time.Second, time.Second*10)
assert.Nil(t, tw.lastTick)
tw.advance(time.Now())
assert.NotNil(t, tw.lastTick)
assert.Equal(t, 0, tw.current)
fps := []uint32{9, 10, 11, 12}
//fp1 := ip2int(net.ParseIP("1.2.3.4"))
tw.Add(fps[0], time.Second*1)
tw.Add(fps[1], time.Second*1)
tw.Add(fps[2], time.Second*2)
tw.Add(fps[3], time.Second*2)
ta := time.Now().Add(time.Second * 3)
lastTick := *tw.lastTick
tw.advance(ta)
assert.Equal(t, 3, tw.current)
assert.True(t, tw.lastTick.After(lastTick))
// Make sure we get all 4 packets back
for i := 0; i < 4; i++ {
assert.Contains(t, fps, tw.Purge())
}
// Make sure there aren't any leftover
assert.Nil(t, tw.Purge())
assert.Nil(t, tw.expired.Head)
assert.Nil(t, tw.expired.Tail)
// Make sure we cached the free'd items
assert.Equal(t, 4, tw.itemsCached)
ci := tw.itemCache
for i := 0; i < 4; i++ {
assert.NotNil(t, ci)
ci = ci.Next
}
assert.Nil(t, ci)
// Lets make sure we roll over properly
ta = ta.Add(time.Second * 5)
tw.advance(ta)
assert.Equal(t, 8, tw.current)
ta = ta.Add(time.Second * 2)
tw.advance(ta)
assert.Equal(t, 10, tw.current)
ta = ta.Add(time.Second * 1)
tw.advance(ta)
assert.Equal(t, 0, tw.current)
}

138
timeout_test.go Normal file
View File

@ -0,0 +1,138 @@
package nebula
import (
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestNewTimerWheel(t *testing.T) {
// Make sure we get an object we expect
tw := NewTimerWheel(time.Second, time.Second*10)
assert.Equal(t, 11, tw.wheelLen)
assert.Equal(t, 0, tw.current)
assert.Nil(t, tw.lastTick)
assert.Equal(t, time.Second*1, tw.tickDuration)
assert.Equal(t, time.Second*10, tw.wheelDuration)
assert.Len(t, tw.wheel, 11)
// Assert the math is correct
tw = NewTimerWheel(time.Second*3, time.Second*10)
assert.Equal(t, 4, tw.wheelLen)
tw = NewTimerWheel(time.Second*120, time.Minute*10)
assert.Equal(t, 6, tw.wheelLen)
}
func TestTimerWheel_findWheel(t *testing.T) {
tw := NewTimerWheel(time.Second, time.Second*10)
assert.Len(t, tw.wheel, 11)
// Current + tick + 1 since we don't know how far into current we are
assert.Equal(t, 2, tw.findWheel(time.Second*1))
// Scale up to min duration
assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
// Make sure we hit that last index
assert.Equal(t, 0, tw.findWheel(time.Second*10))
// Scale down to max duration
assert.Equal(t, 0, tw.findWheel(time.Second*11))
tw.current = 1
// Make sure we account for the current position properly
assert.Equal(t, 3, tw.findWheel(time.Second*1))
assert.Equal(t, 1, tw.findWheel(time.Second*10))
}
func TestTimerWheel_Add(t *testing.T) {
tw := NewTimerWheel(time.Second, time.Second*10)
fp1 := FirewallPacket{}
tw.Add(fp1, time.Second*1)
// Make sure we set head and tail properly
assert.NotNil(t, tw.wheel[2])
assert.Equal(t, fp1, tw.wheel[2].Head.Packet)
assert.Nil(t, tw.wheel[2].Head.Next)
assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
assert.Nil(t, tw.wheel[2].Tail.Next)
// Make sure we only modify head
fp2 := FirewallPacket{}
tw.Add(fp2, time.Second*1)
assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
assert.Nil(t, tw.wheel[2].Tail.Next)
// Make sure we use free'd items first
tw.itemCache = &TimeoutItem{}
tw.itemsCached = 1
tw.Add(fp2, time.Second*1)
assert.Nil(t, tw.itemCache)
assert.Equal(t, 0, tw.itemsCached)
}
func TestTimerWheel_Purge(t *testing.T) {
// First advance should set the lastTick and do nothing else
tw := NewTimerWheel(time.Second, time.Second*10)
assert.Nil(t, tw.lastTick)
tw.advance(time.Now())
assert.NotNil(t, tw.lastTick)
assert.Equal(t, 0, tw.current)
fps := []FirewallPacket{
{LocalIP: 1},
{LocalIP: 2},
{LocalIP: 3},
{LocalIP: 4},
}
tw.Add(fps[0], time.Second*1)
tw.Add(fps[1], time.Second*1)
tw.Add(fps[2], time.Second*2)
tw.Add(fps[3], time.Second*2)
ta := time.Now().Add(time.Second * 3)
lastTick := *tw.lastTick
tw.advance(ta)
assert.Equal(t, 3, tw.current)
assert.True(t, tw.lastTick.After(lastTick))
// Make sure we get all 4 packets back
for i := 0; i < 4; i++ {
p, has := tw.Purge()
assert.True(t, has)
assert.Equal(t, fps[i], p)
}
// Make sure there aren't any leftover
_, ok := tw.Purge()
assert.False(t, ok)
assert.Nil(t, tw.expired.Head)
assert.Nil(t, tw.expired.Tail)
// Make sure we cached the free'd items
assert.Equal(t, 4, tw.itemsCached)
ci := tw.itemCache
for i := 0; i < 4; i++ {
assert.NotNil(t, ci)
ci = ci.Next
}
assert.Nil(t, ci)
// Lets make sure we roll over properly
ta = ta.Add(time.Second * 5)
tw.advance(ta)
assert.Equal(t, 8, tw.current)
ta = ta.Add(time.Second * 2)
tw.advance(ta)
assert.Equal(t, 10, tw.current)
ta = ta.Add(time.Second * 1)
tw.advance(ta)
assert.Equal(t, 0, tw.current)
}

108
tun_common.go Normal file
View File

@ -0,0 +1,108 @@
package nebula
import (
"fmt"
"net"
"strconv"
)
type route struct {
mtu int
route *net.IPNet
}
func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
var err error
r := config.Get("tun.routes")
if r == nil {
return []route{}, nil
}
rawRoutes, ok := r.([]interface{})
if !ok {
return nil, fmt.Errorf("tun.routes is not an array")
}
if len(rawRoutes) < 1 {
return []route{}, nil
}
routes := make([]route, len(rawRoutes))
for i, r := range rawRoutes {
m, ok := r.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1)
}
rMtu, ok := m["mtu"]
if !ok {
return nil, fmt.Errorf("entry %v.mtu in tun.routes is not present", i+1)
}
mtu, ok := rMtu.(int)
if !ok {
mtu, err = strconv.Atoi(rMtu.(string))
if err != nil {
return nil, fmt.Errorf("entry %v.mtu in tun.routes is not an integer: %v", i+1, err)
}
}
if mtu < 500 {
return nil, fmt.Errorf("entry %v.mtu in tun.routes is below 500: %v", i+1, mtu)
}
rRoute, ok := m["route"]
if !ok {
return nil, fmt.Errorf("entry %v.route in tun.routes is not present", i+1)
}
r := route{
mtu: mtu,
}
_, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
if err != nil {
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
}
if !ipWithin(network, r.route) {
return nil, fmt.Errorf(
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
i+1,
r.route.String(),
network.String(),
)
}
routes[i] = r
}
return routes, nil
}
func ipWithin(o *net.IPNet, i *net.IPNet) bool {
// Make sure o contains the lowest form of i
if !o.Contains(i.IP.Mask(i.Mask)) {
return false
}
// Find the max ip in i
ip4 := i.IP.To4()
if ip4 == nil {
return false
}
last := make(net.IP, len(ip4))
copy(last, ip4)
for x := range ip4 {
last[x] |= ^i.Mask[x]
}
// Make sure o contains the max
if !o.Contains(last) {
return false
}
return true
}

59
tun_darwin.go Normal file
View File

@ -0,0 +1,59 @@
package nebula
import (
"fmt"
"net"
"os/exec"
"strconv"
"github.com/songgao/water"
)
type Tun struct {
Device string
Cidr *net.IPNet
MTU int
*water.Interface
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Darwin")
}
// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
return &Tun{
Cidr: cidr,
MTU: defaultMTU,
}, nil
}
func (c *Tun) Activate() error {
var err error
c.Interface, err = water.New(water.Config{
DeviceType: water.TUN,
})
if err != nil {
return fmt.Errorf("Activate failed: %v", err)
}
c.Device = c.Interface.Name()
// TODO use syscalls instead of exec.Command
if err = exec.Command("ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
if err = exec.Command("route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
if err = exec.Command("ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
return nil
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}

249
tun_linux.go Normal file
View File

@ -0,0 +1,249 @@
package nebula
import (
"fmt"
"io"
"net"
"os"
"strings"
"syscall"
"unsafe"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
type Tun struct {
io.ReadWriteCloser
fd int
Device string
Cidr *net.IPNet
MaxMTU int
DefaultMTU int
TXQueueLen int
Routes []route
}
type ifReq struct {
Name [16]byte
Flags uint16
pad [8]byte
}
func ioctl(a1, a2, a3 uintptr) error {
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
if errno != 0 {
return errno
}
return nil
}
/*
func ipv4(addr string) (o [4]byte, err error) {
ip := net.ParseIP(addr).To4()
if ip == nil {
err = fmt.Errorf("failed to parse addr %s", addr)
return
}
for i, b := range ip {
o[i] = b
}
return
}
*/
const (
cIFF_TUN = 0x0001
cIFF_NO_PI = 0x1000
)
type ifreqAddr struct {
Name [16]byte
Addr syscall.RawSockaddrInet4
pad [8]byte
}
type ifreqMTU struct {
Name [16]byte
MTU int
pad [8]byte
}
type ifreqQLEN struct {
Name [16]byte
Value int
pad [8]byte
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
}
var req ifReq
req.Flags = uint16(cIFF_TUN | cIFF_NO_PI)
copy(req.Name[:], deviceName)
if err = ioctl(uintptr(fd), uintptr(syscall.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return
}
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
maxMTU := defaultMTU
for _, r := range routes {
if r.mtu > maxMTU {
maxMTU = r.mtu
}
}
ifce = &Tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: name,
Cidr: cidr,
MaxMTU: maxMTU,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
}
return
}
func (c *Tun) WriteRaw(b []byte) error {
var nn int
for {
max := len(b)
n, err := syscall.Write(c.fd, b[nn:max])
if n > 0 {
nn += n
}
if nn == len(b) {
return err
}
if err != nil {
return err
}
if n == 0 {
return io.ErrUnexpectedEOF
}
}
}
func (c Tun) deviceBytes() (o [16]byte) {
for i, c := range c.Device {
o[i] = byte(c)
}
return
}
func (c Tun) Activate() error {
devName := c.deviceBytes()
var addr, mask [4]byte
copy(addr[:], c.Cidr.IP.To4())
copy(mask[:], c.Cidr.Mask)
s, err := syscall.Socket(
syscall.AF_INET,
syscall.SOCK_DGRAM,
syscall.IPPROTO_IP,
)
if err != nil {
return err
}
fd := uintptr(s)
ifra := ifreqAddr{
Name: devName,
Addr: syscall.RawSockaddrInet4{
Family: syscall.AF_INET,
Addr: addr,
},
}
// Set the device ip address
if err = ioctl(fd, syscall.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return err
}
// Set the device network
ifra.Addr.Addr = mask
if err = ioctl(fd, syscall.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return err
}
// Set the device name
ifrf := ifReq{Name: devName}
if err = ioctl(fd, syscall.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return err
}
// Set the MTU on the device
ifm := ifreqMTU{Name: devName, MTU: c.MaxMTU}
if err = ioctl(fd, syscall.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
return err
}
// Set the transmit queue length
ifrq := ifreqQLEN{Name: devName, Value: c.TXQueueLen}
if err = ioctl(fd, syscall.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
return err
}
// Bring up the interface
ifrf.Flags = ifrf.Flags | syscall.IFF_UP
if err = ioctl(fd, syscall.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return err
}
// Set the routes
link, err := netlink.LinkByName(c.Device)
if err != nil {
return err
}
// Default route
dr := &net.IPNet{IP: c.Cidr.IP.Mask(c.Cidr.Mask), Mask: c.Cidr.Mask}
nr := netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: dr,
MTU: c.DefaultMTU,
Scope: unix.RT_SCOPE_LINK,
Src: c.Cidr.IP,
Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
}
err = netlink.RouteReplace(&nr)
if err != nil {
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", c.DefaultMTU, dr, err)
}
// Path routes
for _, r := range c.Routes {
nr := netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: r.route,
MTU: r.mtu,
Scope: unix.RT_SCOPE_LINK,
}
err = netlink.RouteAdd(&nr)
if err != nil {
return fmt.Errorf("failed to set mtu %v on route %v; %v", r.mtu, r.route, err)
}
}
// Run the interface
ifrf.Flags = ifrf.Flags | syscall.IFF_UP | syscall.IFF_RUNNING
if err = ioctl(fd, syscall.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return err
}
return nil
}

102
tun_test.go Normal file
View File

@ -0,0 +1,102 @@
package nebula
import (
"github.com/stretchr/testify/assert"
"net"
"testing"
)
func Test_parseRoutes(t *testing.T) {
c := NewConfig()
_, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config
routes, err := parseRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 0)
// not an array
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
routes, err = parseRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "tun.routes is not an array")
// no routes
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
routes, err = parseRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 0)
// weird route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
routes, err = parseRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
// no mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
// bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
routes, err = parseRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
routes, err = parseRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
// missing route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
routes, err = parseRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
// unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
routes, err = parseRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope")
// below network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
routes, err = parseRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24")
// above network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
routes, err = parseRoutes(c, n)
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24")
// happy case
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
}}
routes, err = parseRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 2)
tested := 0
for _, r := range routes {
if r.mtu == 8000 {
assert.Equal(t, "10.0.0.1/32", r.route.String())
tested++
} else {
assert.Equal(t, 9000, r.mtu)
assert.Equal(t, "10.0.0.0/29", r.route.String())
tested++
}
}
if tested != 2 {
t.Fatal("Did not see both routes")
}
}

72
tun_windows.go Normal file
View File

@ -0,0 +1,72 @@
package nebula
import (
"fmt"
"net"
"os/exec"
"github.com/songgao/water"
)
type Tun struct {
Device string
Cidr *net.IPNet
MTU int
*water.Interface
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in Windows")
}
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
return &Tun{
Cidr: cidr,
MTU: defaultMTU,
}, nil
}
func (c *Tun) Activate() error {
var err error
c.Interface, err = water.New(water.Config{
DeviceType: water.TUN,
PlatformSpecificParams: water.PlatformSpecificParams{
ComponentID: "tap0901",
Network: c.Cidr.String(),
},
})
if err != nil {
return fmt.Errorf("Activate failed: %v", err)
}
c.Device = c.Interface.Name()
// TODO use syscalls instead of exec.Command
err = exec.Command(
"netsh", "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", c.Device),
"source=static",
fmt.Sprintf("addr=%s", c.Cidr.IP),
fmt.Sprintf("mask=%s", net.IP(c.Cidr.Mask)),
"gateway=none",
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
}
err = exec.Command(
"netsh", "interface", "ipv4", "set", "interface",
c.Device,
fmt.Sprintf("mtu=%d", c.MTU),
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
}
return nil
}
func (c *Tun) WriteRaw(b []byte) error {
_, err := c.Write(b)
return err
}

34
udp_darwin.go Normal file
View File

@ -0,0 +1,34 @@
package nebula
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
import (
"fmt"
"net"
"syscall"
"golang.org/x/sys/unix"
)
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if multi {
var controlErr error
err := c.Control(func(fd uintptr) {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
return
}
})
if err != nil {
return err
}
if controlErr != nil {
return controlErr
}
}
return nil
},
}
}

123
udp_generic.go Normal file
View File

@ -0,0 +1,123 @@
// +build !linux
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows.
package nebula
import (
"context"
"encoding/binary"
"fmt"
"net"
"strconv"
"strings"
)
type udpAddr struct {
net.UDPAddr
}
type udpConn struct {
*net.UDPConn
}
func NewUDPAddr(ip uint32, port uint16) *udpAddr {
return &udpAddr{
UDPAddr: net.UDPAddr{
IP: int2ip(ip),
Port: int(port),
},
}
}
func NewUDPAddrFromString(s string) *udpAddr {
p := strings.Split(s, ":")
if len(p) < 2 {
return nil
}
port, _ := strconv.Atoi(p[1])
return &udpAddr{
UDPAddr: net.UDPAddr{
IP: net.ParseIP(p[0]),
Port: port,
},
}
}
func NewListener(ip string, port int, multi bool) (*udpConn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp4", fmt.Sprintf("%s:%d", ip, port))
if err != nil {
return nil, err
}
if uc, ok := pc.(*net.UDPConn); ok {
return &udpConn{UDPConn: uc}, nil
}
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
}
func (ua *udpAddr) Equals(t *udpAddr) bool {
if t == nil || ua == nil {
return t == nil && ua == nil
}
return ua.IP.Equal(t.IP) && ua.Port == t.Port
}
func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
return err
}
func (uc *udpConn) LocalAddr() (*udpAddr, error) {
a := uc.UDPConn.LocalAddr()
switch v := a.(type) {
case *net.UDPAddr:
return &udpAddr{UDPAddr: *v}, nil
default:
return nil, fmt.Errorf("LocalAddr returned: %#v", a)
}
}
func (u *udpConn) reloadConfig(c *Config) {
// TODO
}
type rawMessage struct {
Len uint32
}
func (u *udpConn) ListenOut(f *Interface) {
plaintext := make([]byte, mtu)
buffer := make([]byte, mtu)
header := &Header{}
fwPacket := &FirewallPacket{}
udpAddr := &udpAddr{}
nb := make([]byte, 12, 12)
for {
// Just read one packet at a time
n, rua, err := u.ReadFromUDP(buffer)
if err != nil {
l.WithError(err).Error("Failed to read packets")
continue
}
udpAddr.UDPAddr = *rua
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, nb)
}
}
func udp2ip(addr *udpAddr) net.IP {
return addr.IP
}
func udp2ipInt(addr *udpAddr) uint32 {
return binary.BigEndian.Uint32(addr.IP.To4())
}
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
return !addr.Equals(newaddr)
}

315
udp_linux.go Normal file
View File

@ -0,0 +1,315 @@
package nebula
import (
"encoding/binary"
"encoding/json"
"fmt"
"net"
"strconv"
"strings"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
//TODO: make it support reload as best you can!
type udpConn struct {
sysFd int
}
type udpAddr struct {
IP uint32
Port uint16
}
func NewUDPAddr(ip uint32, port uint16) *udpAddr {
return &udpAddr{IP: ip, Port: port}
}
func NewUDPAddrFromString(s string) *udpAddr {
p := strings.Split(s, ":")
if len(p) < 2 {
return nil
}
port, _ := strconv.Atoi(p[1])
return &udpAddr{
IP: ip2int(net.ParseIP(p[0])),
Port: uint16(port),
}
}
type rawSockaddr struct {
Family uint16
Data [14]uint8
}
type rawSockaddrAny struct {
Addr rawSockaddr
Pad [96]int8
}
var x int
func NewListener(ip string, port int, multi bool) (*udpConn, error) {
syscall.ForkLock.RLock()
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
if err == nil {
syscall.CloseOnExec(fd)
}
syscall.ForkLock.RUnlock()
if err != nil {
syscall.Close(fd)
return nil, err
}
var lip [4]byte
copy(lip[:], net.ParseIP(ip).To4())
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, 0x0F, 1); err != nil {
return nil, err
}
if err = syscall.Bind(fd, &syscall.SockaddrInet4{Port: port}); err != nil {
return nil, err
}
// SO_REUSEADDR does not load balance so we use PORT
if multi {
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
return nil, err
}
}
//TODO: this may be useful for forcing threads into specific cores
//syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, unix.SO_INCOMING_CPU, x)
//v, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err)
return &udpConn{sysFd: fd}, err
}
func (u *udpConn) SetRecvBuffer(n int) error {
return syscall.SetsockoptInt(u.sysFd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, n)
}
func (u *udpConn) SetSendBuffer(n int) error {
return syscall.SetsockoptInt(u.sysFd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, n)
}
func (u *udpConn) GetRecvBuffer() (int, error) {
return syscall.GetsockoptInt(int(u.sysFd), syscall.SOL_SOCKET, syscall.SO_RCVBUF)
}
func (u *udpConn) GetSendBuffer() (int, error) {
return syscall.GetsockoptInt(int(u.sysFd), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
}
func (u *udpConn) LocalAddr() (*udpAddr, error) {
var rsa rawSockaddrAny
var rLen = syscall.SizeofSockaddrAny
_, _, err := syscall.Syscall(
syscall.SYS_GETSOCKNAME,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&rsa)),
uintptr(unsafe.Pointer(&rLen)),
)
if err != 0 {
return nil, err
}
addr := &udpAddr{}
if rsa.Addr.Family == syscall.AF_INET {
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
addr.IP = uint32(rsa.Addr.Data[2])<<24 + uint32(rsa.Addr.Data[3])<<16 + uint32(rsa.Addr.Data[4])<<8 + uint32(rsa.Addr.Data[5])
} else {
addr.Port = 0
addr.IP = 0
}
return addr, nil
}
func (u *udpConn) ListenOut(f *Interface) {
plaintext := make([]byte, mtu)
header := &Header{}
fwPacket := &FirewallPacket{}
udpAddr := &udpAddr{}
nb := make([]byte, 12, 12)
//TODO: should we track this?
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
for {
n, err := u.ReadMulti(msgs)
if err != nil {
l.WithError(err).Error("Failed to read packets")
continue
}
//metric.Update(int64(n))
for i := 0; i < n; i++ {
udpAddr.IP = binary.BigEndian.Uint32(names[i][4:8])
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, nb)
}
}
}
func (u *udpConn) Read(addr *udpAddr, b []byte) ([]byte, error) {
var rsa rawSockaddrAny
var rLen = syscall.SizeofSockaddrAny
for {
n, _, err := syscall.Syscall6(
syscall.SYS_RECVFROM,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&b[0])),
uintptr(len(b)),
uintptr(0),
uintptr(unsafe.Pointer(&rsa)),
uintptr(unsafe.Pointer(&rLen)),
)
if err != 0 {
return nil, &net.OpError{Op: "read", Err: err}
}
if rsa.Addr.Family == syscall.AF_INET {
addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
addr.IP = uint32(rsa.Addr.Data[2])<<24 + uint32(rsa.Addr.Data[3])<<16 + uint32(rsa.Addr.Data[4])<<8 + uint32(rsa.Addr.Data[5])
} else {
addr.Port = 0
addr.IP = 0
}
return b[:n], nil
}
}
func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) {
for {
n, _, err := syscall.Syscall6(
syscall.SYS_RECVMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[0])),
uintptr(len(msgs)),
unix.MSG_WAITFORONE,
0,
0,
)
if err != 0 {
return 0, &net.OpError{Op: "recvmmsg", Err: err}
}
return int(n), nil
}
}
func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
var rsa syscall.RawSockaddrInet4
//TODO: sometimes addr is nil!
rsa.Family = syscall.AF_INET
p := (*[2]byte)(unsafe.Pointer(&rsa.Port))
p[0] = byte(addr.Port >> 8)
p[1] = byte(addr.Port)
rsa.Addr[0] = byte(addr.IP & 0xff000000 >> 24)
rsa.Addr[1] = byte(addr.IP & 0x00ff0000 >> 16)
rsa.Addr[2] = byte(addr.IP & 0x0000ff00 >> 8)
rsa.Addr[3] = byte(addr.IP & 0x000000ff)
for {
_, _, err := syscall.Syscall6(
syscall.SYS_SENDTO,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&b[0])),
uintptr(len(b)),
uintptr(0),
uintptr(unsafe.Pointer(&rsa)),
uintptr(syscall.SizeofSockaddrInet4),
)
if err != 0 {
return &net.OpError{Op: "sendto", Err: err}
}
//TODO: handle incomplete writes
return nil
}
}
func (u *udpConn) reloadConfig(c *Config) {
b := c.GetInt("listen.read_buffer", 0)
if b > 0 {
err := u.SetRecvBuffer(b)
if err == nil {
s, err := u.GetRecvBuffer()
if err == nil {
l.WithField("size", s).Info("listen.read_buffer was set")
} else {
l.WithError(err).Warn("Failed to get listen.read_buffer")
}
} else {
l.WithError(err).Error("Failed to set listen.read_buffer")
}
}
b = c.GetInt("listen.write_buffer", 0)
if b > 0 {
err := u.SetSendBuffer(b)
if err == nil {
s, err := u.GetSendBuffer()
if err == nil {
l.WithField("size", s).Info("listen.write_buffer was set")
} else {
l.WithError(err).Warn("Failed to get listen.write_buffer")
}
} else {
l.WithError(err).Error("Failed to set listen.write_buffer")
}
}
}
func (ua *udpAddr) Equals(t *udpAddr) bool {
if t == nil || ua == nil {
return t == nil && ua == nil
}
return ua.IP == t.IP && ua.Port == t.Port
}
func (ua *udpAddr) Copy() *udpAddr {
return &udpAddr{
Port: ua.Port,
IP: ua.IP,
}
}
func (ua *udpAddr) String() string {
return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
}
func (ua *udpAddr) MarshalJSON() ([]byte, error) {
return json.Marshal(m{"ip": int2ip(ua.IP), "port": ua.Port})
}
func udp2ip(addr *udpAddr) net.IP {
return int2ip(addr.IP)
}
func udp2ipInt(addr *udpAddr) uint32 {
return addr.IP
}
func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
return !addr.Equals(newaddr)
}

50
udp_linux_amd64.go Normal file
View File

@ -0,0 +1,50 @@
package nebula
import "unsafe"
type iovec struct {
Base *byte
Len uint64
}
type msghdr struct {
Name *byte
Namelen uint32
Pad0 [4]byte
Iov *iovec
Iovlen uint64
Control *byte
Controllen uint64
Flags int32
Pad1 [4]byte
}
type rawMessage struct {
Hdr msghdr
Len uint32
Pad0 [4]byte
}
func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)
for i := range msgs {
buffers[i] = make([]byte, mtu)
names[i] = make([]byte, 0x1c) //TODO = sizeofSockaddrInet6
//TODO: this is still silly, no need for an array
vs := []iovec{
{Base: (*byte)(unsafe.Pointer(&buffers[i][0])), Len: uint64(len(buffers[i]))},
}
msgs[i].Hdr.Iov = &vs[0]
msgs[i].Hdr.Iovlen = uint64(len(vs))
msgs[i].Hdr.Name = (*byte)(unsafe.Pointer(&names[i][0]))
msgs[i].Hdr.Namelen = uint32(len(names[i]))
}
return msgs, buffers, names
}

Some files were not shown because too many files have changed in this diff Show More