mirror of https://github.com/slackhq/nebula.git
Move util to test, contextual errors to util (#575)
This commit is contained in:
parent
19a9a4221e
commit
4453964e34
|
@ -7,12 +7,12 @@ import (
|
|||
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewAllowListFromConfig(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"192.168.0.0": true,
|
||||
|
|
10
bits_test.go
10
bits_test.go
|
@ -3,12 +3,12 @@ package nebula
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
b := NewBits(10)
|
||||
|
||||
// make sure it is the right size
|
||||
|
@ -76,7 +76,7 @@ func TestBits(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBitsDupeCounter(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
b := NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
b.dupeCounter.Clear()
|
||||
|
@ -101,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBitsOutOfWindowCounter(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
b := NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
b.dupeCounter.Clear()
|
||||
|
@ -131,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBitsLostCounter(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
b := NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
b.dupeCounter.Clear()
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
@ -752,7 +752,7 @@ func TestNebulaCertificate_Copy(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
cc := c.Copy()
|
||||
|
||||
util.AssertDeepCopyEqual(t, c, cc)
|
||||
test.AssertDeepCopyEqual(t, c, cc)
|
||||
}
|
||||
|
||||
func TestUnmarshalNebulaCertificate(t *testing.T) {
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
// A version string that can be set with
|
||||
|
@ -60,7 +61,7 @@ func main() {
|
|||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||
|
||||
switch v := err.(type) {
|
||||
case nebula.ContextualError:
|
||||
case util.ContextualError:
|
||||
v.Log(l)
|
||||
os.Exit(1)
|
||||
case error:
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
// A version string that can be set with
|
||||
|
@ -54,7 +55,7 @@ func main() {
|
|||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||
|
||||
switch v := err.(type) {
|
||||
case nebula.ContextualError:
|
||||
case util.ContextualError:
|
||||
v.Log(l)
|
||||
os.Exit(1)
|
||||
case error:
|
||||
|
|
|
@ -7,12 +7,12 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
dir, err := ioutil.TempDir("", "config-test")
|
||||
// invalid yaml
|
||||
c := NewC(l)
|
||||
|
@ -42,7 +42,7 @@ func TestConfig_Load(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestConfig_Get(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
// test simple type
|
||||
c := NewC(l)
|
||||
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
|
||||
|
@ -58,14 +58,14 @@ func TestConfig_Get(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestConfig_GetStringSlice(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
c := NewC(l)
|
||||
c.Settings["slice"] = []interface{}{"one", "two"}
|
||||
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
||||
}
|
||||
|
||||
func TestConfig_GetBool(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
c := NewC(l)
|
||||
c.Settings["bool"] = true
|
||||
assert.Equal(t, true, c.GetBool("bool", false))
|
||||
|
@ -93,7 +93,7 @@ func TestConfig_GetBool(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestConfig_HasChanged(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
// No reload has occurred, return false
|
||||
c := NewC(l)
|
||||
c.Settings["test"] = "hi"
|
||||
|
@ -115,7 +115,7 @@ func TestConfig_HasChanged(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestConfig_ReloadConfig(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
done := make(chan bool, 1)
|
||||
dir, err := ioutil.TempDir("", "config-test")
|
||||
assert.Nil(t, err)
|
||||
|
|
|
@ -11,15 +11,15 @@ import (
|
|||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var vpnIp iputil.VpnIp
|
||||
|
||||
func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
|
@ -89,7 +89,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
|
@ -164,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||
// Disconnect only if disconnectInvalid: true is set.
|
||||
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
now := time.Now()
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
ipNet := net.IPNet{
|
||||
IP: net.IPv4(172, 1, 1, 2),
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
|
|
|
@ -9,13 +9,13 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||
// To properly ensure we are not exposing core memory to the caller
|
||||
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
|
||||
|
@ -94,7 +94,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||
|
||||
// Make sure we don't have any unexpected fields
|
||||
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
|
||||
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||
|
||||
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||
assert.NotPanics(t, func() {
|
||||
|
|
|
@ -14,12 +14,12 @@ import (
|
|||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewFirewall(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
c := &cert.NebulaCertificate{}
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
conntrack := fw.Conntrack
|
||||
|
@ -58,7 +58,7 @@ func TestNewFirewall(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFirewall_AddRule(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
|
@ -133,7 +133,7 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFirewall_Drop(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
|
@ -308,7 +308,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||
}
|
||||
|
||||
func TestFirewall_Drop2(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
|
@ -367,7 +367,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFirewall_Drop3(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
|
@ -453,7 +453,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
|
@ -635,7 +635,7 @@ func Test_parsePort(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNewFirewallFromConfig(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
// Test a bad rule definition
|
||||
c := &cert.NebulaCertificate{}
|
||||
conf := config.NewC(l)
|
||||
|
@ -685,7 +685,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
// Test adding tcp rule
|
||||
conf := config.NewC(l)
|
||||
mf := &mockFirewall{}
|
||||
|
@ -849,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFirewall_convertRule(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
|
|
|
@ -7,13 +7,13 @@ import (
|
|||
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
|
@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
|
|
|
@ -8,8 +8,8 @@ import (
|
|||
"github.com/golang/protobuf/proto"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -46,7 +46,7 @@ func TestNewLhQuery(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_lhStaticMapping(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
lh1 := "10.128.0.2"
|
||||
lh1IP := net.ParseIP(lh1)
|
||||
|
||||
|
@ -67,7 +67,7 @@ func Test_lhStaticMapping(t *testing.T) {
|
|||
}
|
||||
|
||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
lh1 := "10.128.0.2"
|
||||
lh1IP := net.ParseIP(lh1)
|
||||
|
||||
|
@ -137,7 +137,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||
}
|
||||
|
||||
func TestLighthouse_Memory(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
|
||||
myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
|
||||
myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
|
||||
|
@ -266,7 +266,7 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
|
|||
|
||||
//TODO: this is a RemoteList test
|
||||
//func Test_lhRemoteAllowList(t *testing.T) {
|
||||
// l := NewTestLogger()
|
||||
// l := NewLogger()
|
||||
// c := NewConfig(l)
|
||||
// c.Settings["remoteallowlist"] = map[interface{}]interface{}{
|
||||
// "10.20.0.0/12": false,
|
||||
|
|
33
logger.go
33
logger.go
|
@ -1,7 +1,6 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -10,38 +9,6 @@ import (
|
|||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
type ContextualError struct {
|
||||
RealError error
|
||||
Fields map[string]interface{}
|
||||
Context string
|
||||
}
|
||||
|
||||
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
|
||||
return ContextualError{Context: msg, Fields: fields, RealError: realError}
|
||||
}
|
||||
|
||||
func (ce ContextualError) Error() string {
|
||||
if ce.RealError == nil {
|
||||
return ce.Context
|
||||
}
|
||||
return ce.RealError.Error()
|
||||
}
|
||||
|
||||
func (ce ContextualError) Unwrap() error {
|
||||
if ce.RealError == nil {
|
||||
return errors.New(ce.Context)
|
||||
}
|
||||
return ce.RealError
|
||||
}
|
||||
|
||||
func (ce *ContextualError) Log(lr *logrus.Logger) {
|
||||
if ce.RealError != nil {
|
||||
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
|
||||
} else {
|
||||
lr.WithFields(ce.Fields).Error(ce.Context)
|
||||
}
|
||||
}
|
||||
|
||||
func configLogger(l *logrus.Logger, c *config.C) error {
|
||||
// set up our logging level
|
||||
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
||||
|
|
43
main.go
43
main.go
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/sshd"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
|
@ -44,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
|
||||
err := configLogger(l, c)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to configure the logger", nil, err)
|
||||
return nil, util.NewContextualError("Failed to configure the logger", nil, err)
|
||||
}
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
|
@ -57,20 +58,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
caPool, err := loadCAFromConfig(l, c)
|
||||
if err != nil {
|
||||
//The errors coming out of loadCA are already nicely formatted
|
||||
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
||||
return nil, util.NewContextualError("Failed to load ca from config", nil, err)
|
||||
}
|
||||
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||
|
||||
cs, err := NewCertStateFromConfig(c)
|
||||
if err != nil {
|
||||
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
||||
return nil, NewContextualError("Failed to load certificate from config", nil, err)
|
||||
return nil, util.NewContextualError("Failed to load certificate from config", nil, err)
|
||||
}
|
||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||
|
||||
fw, err := NewFirewallFromConfig(l, cs.certificate, c)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
||||
return nil, util.NewContextualError("Error while loading firewall rules", nil, err)
|
||||
}
|
||||
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
|
||||
|
||||
|
@ -78,11 +79,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
tunCidr := cs.certificate.Details.Ips[0]
|
||||
routes, err := parseRoutes(c, tunCidr)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Could not parse tun.routes", nil, err)
|
||||
return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
|
||||
}
|
||||
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||
return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||
}
|
||||
|
||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||
|
@ -91,7 +92,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
if c.GetBool("sshd.enabled", false) {
|
||||
sshStart, err = configSSH(l, ssh, c)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
||||
return nil, util.NewContextualError("Error while configuring the sshd", nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -167,7 +168,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
|
||||
return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -185,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
for i := 0; i < routines; i++ {
|
||||
udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||
}
|
||||
udpServer.ReloadConfig(c)
|
||||
udpConns[i] = udpServer
|
||||
|
@ -194,7 +195,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
if port == 0 {
|
||||
uPort, err := udpServer.LocalAddr()
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to get listening port", nil, err)
|
||||
return nil, util.NewContextualError("Failed to get listening port", nil, err)
|
||||
}
|
||||
port = int(uPort.Port)
|
||||
}
|
||||
|
@ -209,7 +210,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
for _, rawPreferredRange := range rawPreferredRanges {
|
||||
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
|
||||
return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err)
|
||||
}
|
||||
preferredRanges = append(preferredRanges, preferredRange)
|
||||
}
|
||||
|
@ -222,7 +223,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
if rawLocalRange != "" {
|
||||
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to parse local_range", nil, err)
|
||||
return nil, util.NewContextualError("Failed to parse local_range", nil, err)
|
||||
}
|
||||
|
||||
// Check if the entry for local_range was already specified in
|
||||
|
@ -261,7 +262,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
|
||||
// fatal if am_lighthouse is enabled but we are using an ephemeral port
|
||||
if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
|
||||
return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
|
||||
return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
|
||||
}
|
||||
|
||||
// warn if am_lighthouse is enabled but upstream lighthouses exists
|
||||
|
@ -274,10 +275,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
for i, host := range rawLighthouseHosts {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
||||
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
||||
}
|
||||
if !tunCidr.Contains(ip) {
|
||||
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||
return nil, util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||
}
|
||||
lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
|
||||
}
|
||||
|
@ -298,13 +299,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
|
||||
remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
||||
return nil, util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
||||
}
|
||||
lightHouse.SetRemoteAllowList(remoteAllowList)
|
||||
|
||||
localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
||||
return nil, util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
||||
}
|
||||
lightHouse.SetLocalAllowList(localAllowList)
|
||||
|
||||
|
@ -313,21 +314,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
ip := net.ParseIP(fmt.Sprintf("%v", k))
|
||||
vpnIp := iputil.Ip2VpnIp(ip)
|
||||
if !tunCidr.Contains(ip) {
|
||||
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
||||
return nil, util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
||||
}
|
||||
vals, ok := v.([]interface{})
|
||||
if ok {
|
||||
for _, v := range vals {
|
||||
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
}
|
||||
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
|
||||
}
|
||||
} else {
|
||||
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||
}
|
||||
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
|
||||
}
|
||||
|
@ -426,7 +427,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
statsStart, err := startStats(l, c, buildVersion, configTest)
|
||||
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
||||
return nil, util.NewContextualError("Failed to start stats emitter", nil, err)
|
||||
}
|
||||
|
||||
if configTest {
|
||||
|
|
|
@ -5,12 +5,12 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewPunchyFromConfig(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
|
||||
// Test defaults
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package util
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
|
@ -1,4 +1,4 @@
|
|||
package util
|
||||
package test
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
|
@ -7,7 +7,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewTestLogger() *logrus.Logger {
|
||||
func NewLogger() *logrus.Logger {
|
||||
l := logrus.New()
|
||||
|
||||
v := os.Getenv("TEST_LOGS")
|
|
@ -6,12 +6,12 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseRoutes(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||
|
||||
|
@ -107,7 +107,7 @@ func Test_parseRoutes(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_parseUnsafeRoutes(t *testing.T) {
|
||||
l := util.NewTestLogger()
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type ContextualError struct {
|
||||
RealError error
|
||||
Fields map[string]interface{}
|
||||
Context string
|
||||
}
|
||||
|
||||
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
|
||||
return ContextualError{Context: msg, Fields: fields, RealError: realError}
|
||||
}
|
||||
|
||||
func (ce ContextualError) Error() string {
|
||||
if ce.RealError == nil {
|
||||
return ce.Context
|
||||
}
|
||||
return ce.RealError.Error()
|
||||
}
|
||||
|
||||
func (ce ContextualError) Unwrap() error {
|
||||
if ce.RealError == nil {
|
||||
return errors.New(ce.Context)
|
||||
}
|
||||
return ce.RealError
|
||||
}
|
||||
|
||||
func (ce *ContextualError) Log(lr *logrus.Logger) {
|
||||
if ce.RealError != nil {
|
||||
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
|
||||
} else {
|
||||
lr.WithFields(ce.Fields).Error(ce.Context)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package nebula
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
@ -8,6 +8,8 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type m map[string]interface{}
|
||||
|
||||
type TestLogWriter struct {
|
||||
Logs []string
|
||||
}
|
Loading…
Reference in New Issue