mirror of https://github.com/slackhq/nebula.git
Move util to test, contextual errors to util (#575)
This commit is contained in:
parent
19a9a4221e
commit
4453964e34
|
@ -7,12 +7,12 @@ import (
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cidr"
|
"github.com/slackhq/nebula/cidr"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewAllowListFromConfig(t *testing.T) {
|
func TestNewAllowListFromConfig(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||||
"192.168.0.0": true,
|
"192.168.0.0": true,
|
||||||
|
|
10
bits_test.go
10
bits_test.go
|
@ -3,12 +3,12 @@ package nebula
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBits(t *testing.T) {
|
func TestBits(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
|
|
||||||
// make sure it is the right size
|
// make sure it is the right size
|
||||||
|
@ -76,7 +76,7 @@ func TestBits(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsDupeCounter(t *testing.T) {
|
func TestBitsDupeCounter(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
|
@ -101,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsOutOfWindowCounter(t *testing.T) {
|
func TestBitsOutOfWindowCounter(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
|
@ -131,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitsLostCounter(t *testing.T) {
|
func TestBitsLostCounter(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
b := NewBits(10)
|
b := NewBits(10)
|
||||||
b.lostCounter.Clear()
|
b.lostCounter.Clear()
|
||||||
b.dupeCounter.Clear()
|
b.dupeCounter.Clear()
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/crypto/curve25519"
|
"golang.org/x/crypto/curve25519"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
|
@ -752,7 +752,7 @@ func TestNebulaCertificate_Copy(t *testing.T) {
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
cc := c.Copy()
|
cc := c.Copy()
|
||||||
|
|
||||||
util.AssertDeepCopyEqual(t, c, cc)
|
test.AssertDeepCopyEqual(t, c, cc)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalNebulaCertificate(t *testing.T) {
|
func TestUnmarshalNebulaCertificate(t *testing.T) {
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A version string that can be set with
|
// A version string that can be set with
|
||||||
|
@ -60,7 +61,7 @@ func main() {
|
||||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||||
|
|
||||||
switch v := err.(type) {
|
switch v := err.(type) {
|
||||||
case nebula.ContextualError:
|
case util.ContextualError:
|
||||||
v.Log(l)
|
v.Log(l)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
case error:
|
case error:
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A version string that can be set with
|
// A version string that can be set with
|
||||||
|
@ -54,7 +55,7 @@ func main() {
|
||||||
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
|
||||||
|
|
||||||
switch v := err.(type) {
|
switch v := err.(type) {
|
||||||
case nebula.ContextualError:
|
case util.ContextualError:
|
||||||
v.Log(l)
|
v.Log(l)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
case error:
|
case error:
|
||||||
|
|
|
@ -7,12 +7,12 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfig_Load(t *testing.T) {
|
func TestConfig_Load(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
dir, err := ioutil.TempDir("", "config-test")
|
dir, err := ioutil.TempDir("", "config-test")
|
||||||
// invalid yaml
|
// invalid yaml
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
|
@ -42,7 +42,7 @@ func TestConfig_Load(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_Get(t *testing.T) {
|
func TestConfig_Get(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
// test simple type
|
// test simple type
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
|
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
|
||||||
|
@ -58,14 +58,14 @@ func TestConfig_Get(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_GetStringSlice(t *testing.T) {
|
func TestConfig_GetStringSlice(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
c.Settings["slice"] = []interface{}{"one", "two"}
|
c.Settings["slice"] = []interface{}{"one", "two"}
|
||||||
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_GetBool(t *testing.T) {
|
func TestConfig_GetBool(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
c.Settings["bool"] = true
|
c.Settings["bool"] = true
|
||||||
assert.Equal(t, true, c.GetBool("bool", false))
|
assert.Equal(t, true, c.GetBool("bool", false))
|
||||||
|
@ -93,7 +93,7 @@ func TestConfig_GetBool(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_HasChanged(t *testing.T) {
|
func TestConfig_HasChanged(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
// No reload has occurred, return false
|
// No reload has occurred, return false
|
||||||
c := NewC(l)
|
c := NewC(l)
|
||||||
c.Settings["test"] = "hi"
|
c.Settings["test"] = "hi"
|
||||||
|
@ -115,7 +115,7 @@ func TestConfig_HasChanged(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_ReloadConfig(t *testing.T) {
|
func TestConfig_ReloadConfig(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
done := make(chan bool, 1)
|
done := make(chan bool, 1)
|
||||||
dir, err := ioutil.TempDir("", "config-test")
|
dir, err := ioutil.TempDir("", "config-test")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
|
@ -11,15 +11,15 @@ import (
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var vpnIp iputil.VpnIp
|
var vpnIp iputil.VpnIp
|
||||||
|
|
||||||
func Test_NewConnectionManagerTest(t *testing.T) {
|
func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
@ -89,7 +89,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewConnectionManagerTest2(t *testing.T) {
|
func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
@ -164,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
// Disconnect only if disconnectInvalid: true is set.
|
// Disconnect only if disconnectInvalid: true is set.
|
||||||
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
ipNet := net.IPNet{
|
ipNet := net.IPNet{
|
||||||
IP: net.IPv4(172, 1, 1, 2),
|
IP: net.IPv4(172, 1, 1, 2),
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
|
|
@ -9,13 +9,13 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||||
// To properly ensure we are not exposing core memory to the caller
|
// To properly ensure we are not exposing core memory to the caller
|
||||||
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
|
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
|
||||||
|
@ -94,7 +94,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||||
|
|
||||||
// Make sure we don't have any unexpected fields
|
// Make sure we don't have any unexpected fields
|
||||||
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
|
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
|
||||||
util.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||||
|
|
||||||
// Make sure we don't panic if the host info doesn't have a cert yet
|
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
|
|
|
@ -14,12 +14,12 @@ import (
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewFirewall(t *testing.T) {
|
func TestNewFirewall(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||||
conntrack := fw.Conntrack
|
conntrack := fw.Conntrack
|
||||||
|
@ -58,7 +58,7 @@ func TestNewFirewall(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_AddRule(t *testing.T) {
|
func TestFirewall_AddRule(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
|
@ -133,7 +133,7 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop(t *testing.T) {
|
func TestFirewall_Drop(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
|
@ -308,7 +308,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop2(t *testing.T) {
|
func TestFirewall_Drop2(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
|
@ -367,7 +367,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3(t *testing.T) {
|
func TestFirewall_Drop3(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
|
@ -453,7 +453,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
|
@ -635,7 +635,7 @@ func Test_parsePort(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewFirewallFromConfig(t *testing.T) {
|
func TestNewFirewallFromConfig(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
// Test a bad rule definition
|
// Test a bad rule definition
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(l)
|
||||||
|
@ -685,7 +685,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
// Test adding tcp rule
|
// Test adding tcp rule
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(l)
|
||||||
mf := &mockFirewall{}
|
mf := &mockFirewall{}
|
||||||
|
@ -849,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_convertRule(t *testing.T) {
|
func TestFirewall_convertRule(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
|
|
|
@ -7,13 +7,13 @@ import (
|
||||||
|
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||||
|
|
|
@ -8,8 +8,8 @@ import (
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
"github.com/slackhq/nebula/util"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ func TestNewLhQuery(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_lhStaticMapping(t *testing.T) {
|
func Test_lhStaticMapping(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
lh1IP := net.ParseIP(lh1)
|
lh1IP := net.ParseIP(lh1)
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ func Test_lhStaticMapping(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
lh1IP := net.ParseIP(lh1)
|
lh1IP := net.ParseIP(lh1)
|
||||||
|
|
||||||
|
@ -137,7 +137,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLighthouse_Memory(t *testing.T) {
|
func TestLighthouse_Memory(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
|
|
||||||
myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
|
myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
|
||||||
myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
|
myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
|
||||||
|
@ -266,7 +266,7 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
|
||||||
|
|
||||||
//TODO: this is a RemoteList test
|
//TODO: this is a RemoteList test
|
||||||
//func Test_lhRemoteAllowList(t *testing.T) {
|
//func Test_lhRemoteAllowList(t *testing.T) {
|
||||||
// l := NewTestLogger()
|
// l := NewLogger()
|
||||||
// c := NewConfig(l)
|
// c := NewConfig(l)
|
||||||
// c.Settings["remoteallowlist"] = map[interface{}]interface{}{
|
// c.Settings["remoteallowlist"] = map[interface{}]interface{}{
|
||||||
// "10.20.0.0/12": false,
|
// "10.20.0.0/12": false,
|
||||||
|
|
33
logger.go
33
logger.go
|
@ -1,7 +1,6 @@
|
||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -10,38 +9,6 @@ import (
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ContextualError struct {
|
|
||||||
RealError error
|
|
||||||
Fields map[string]interface{}
|
|
||||||
Context string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
|
|
||||||
return ContextualError{Context: msg, Fields: fields, RealError: realError}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ce ContextualError) Error() string {
|
|
||||||
if ce.RealError == nil {
|
|
||||||
return ce.Context
|
|
||||||
}
|
|
||||||
return ce.RealError.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ce ContextualError) Unwrap() error {
|
|
||||||
if ce.RealError == nil {
|
|
||||||
return errors.New(ce.Context)
|
|
||||||
}
|
|
||||||
return ce.RealError
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ce *ContextualError) Log(lr *logrus.Logger) {
|
|
||||||
if ce.RealError != nil {
|
|
||||||
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
|
|
||||||
} else {
|
|
||||||
lr.WithFields(ce.Fields).Error(ce.Context)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func configLogger(l *logrus.Logger, c *config.C) error {
|
func configLogger(l *logrus.Logger, c *config.C) error {
|
||||||
// set up our logging level
|
// set up our logging level
|
||||||
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
|
||||||
|
|
43
main.go
43
main.go
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/slackhq/nebula/iputil"
|
"github.com/slackhq/nebula/iputil"
|
||||||
"github.com/slackhq/nebula/sshd"
|
"github.com/slackhq/nebula/sshd"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
|
|
||||||
err := configLogger(l, c)
|
err := configLogger(l, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to configure the logger", nil, err)
|
return nil, util.NewContextualError("Failed to configure the logger", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
@ -57,20 +58,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
caPool, err := loadCAFromConfig(l, c)
|
caPool, err := loadCAFromConfig(l, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//The errors coming out of loadCA are already nicely formatted
|
//The errors coming out of loadCA are already nicely formatted
|
||||||
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
return nil, util.NewContextualError("Failed to load ca from config", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
|
||||||
|
|
||||||
cs, err := NewCertStateFromConfig(c)
|
cs, err := NewCertStateFromConfig(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
//The errors coming out of NewCertStateFromConfig are already nicely formatted
|
||||||
return nil, NewContextualError("Failed to load certificate from config", nil, err)
|
return nil, util.NewContextualError("Failed to load certificate from config", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||||
|
|
||||||
fw, err := NewFirewallFromConfig(l, cs.certificate, c)
|
fw, err := NewFirewallFromConfig(l, cs.certificate, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
return nil, util.NewContextualError("Error while loading firewall rules", nil, err)
|
||||||
}
|
}
|
||||||
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
|
l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
|
||||||
|
|
||||||
|
@ -78,11 +79,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
tunCidr := cs.certificate.Details.Ips[0]
|
tunCidr := cs.certificate.Details.Ips[0]
|
||||||
routes, err := parseRoutes(c, tunCidr)
|
routes, err := parseRoutes(c, tunCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Could not parse tun.routes", nil, err)
|
return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
|
||||||
}
|
}
|
||||||
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
|
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||||
|
@ -91,7 +92,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
if c.GetBool("sshd.enabled", false) {
|
if c.GetBool("sshd.enabled", false) {
|
||||||
sshStart, err = configSSH(l, ssh, c)
|
sshStart, err = configSSH(l, ssh, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
return nil, util.NewContextualError("Error while configuring the sshd", nil, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,7 +168,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
|
return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
for i := 0; i < routines; i++ {
|
for i := 0; i < routines; i++ {
|
||||||
udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
|
udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||||
}
|
}
|
||||||
udpServer.ReloadConfig(c)
|
udpServer.ReloadConfig(c)
|
||||||
udpConns[i] = udpServer
|
udpConns[i] = udpServer
|
||||||
|
@ -194,7 +195,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
if port == 0 {
|
if port == 0 {
|
||||||
uPort, err := udpServer.LocalAddr()
|
uPort, err := udpServer.LocalAddr()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to get listening port", nil, err)
|
return nil, util.NewContextualError("Failed to get listening port", nil, err)
|
||||||
}
|
}
|
||||||
port = int(uPort.Port)
|
port = int(uPort.Port)
|
||||||
}
|
}
|
||||||
|
@ -209,7 +210,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
for _, rawPreferredRange := range rawPreferredRanges {
|
for _, rawPreferredRange := range rawPreferredRanges {
|
||||||
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
|
return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err)
|
||||||
}
|
}
|
||||||
preferredRanges = append(preferredRanges, preferredRange)
|
preferredRanges = append(preferredRanges, preferredRange)
|
||||||
}
|
}
|
||||||
|
@ -222,7 +223,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
if rawLocalRange != "" {
|
if rawLocalRange != "" {
|
||||||
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to parse local_range", nil, err)
|
return nil, util.NewContextualError("Failed to parse local_range", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the entry for local_range was already specified in
|
// Check if the entry for local_range was already specified in
|
||||||
|
@ -261,7 +262,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
|
|
||||||
// fatal if am_lighthouse is enabled but we are using an ephemeral port
|
// fatal if am_lighthouse is enabled but we are using an ephemeral port
|
||||||
if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
|
if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
|
||||||
return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
|
return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// warn if am_lighthouse is enabled but upstream lighthouses exists
|
// warn if am_lighthouse is enabled but upstream lighthouses exists
|
||||||
|
@ -274,10 +275,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
for i, host := range rawLighthouseHosts {
|
for i, host := range rawLighthouseHosts {
|
||||||
ip := net.ParseIP(host)
|
ip := net.ParseIP(host)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
|
||||||
}
|
}
|
||||||
if !tunCidr.Contains(ip) {
|
if !tunCidr.Contains(ip) {
|
||||||
return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
return nil, util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
|
||||||
}
|
}
|
||||||
lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
|
lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
|
||||||
}
|
}
|
||||||
|
@ -298,13 +299,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
|
|
||||||
remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
|
remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
return nil, util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
|
||||||
}
|
}
|
||||||
lightHouse.SetRemoteAllowList(remoteAllowList)
|
lightHouse.SetRemoteAllowList(remoteAllowList)
|
||||||
|
|
||||||
localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
|
localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
return nil, util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
|
||||||
}
|
}
|
||||||
lightHouse.SetLocalAllowList(localAllowList)
|
lightHouse.SetLocalAllowList(localAllowList)
|
||||||
|
|
||||||
|
@ -313,21 +314,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
ip := net.ParseIP(fmt.Sprintf("%v", k))
|
ip := net.ParseIP(fmt.Sprintf("%v", k))
|
||||||
vpnIp := iputil.Ip2VpnIp(ip)
|
vpnIp := iputil.Ip2VpnIp(ip)
|
||||||
if !tunCidr.Contains(ip) {
|
if !tunCidr.Contains(ip) {
|
||||||
return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
return nil, util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
|
||||||
}
|
}
|
||||||
vals, ok := v.([]interface{})
|
vals, ok := v.([]interface{})
|
||||||
if ok {
|
if ok {
|
||||||
for _, v := range vals {
|
for _, v := range vals {
|
||||||
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||||
}
|
}
|
||||||
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
|
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
|
||||||
}
|
}
|
||||||
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
|
lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
|
||||||
}
|
}
|
||||||
|
@ -426,7 +427,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||||
statsStart, err := startStats(l, c, buildVersion, configTest)
|
statsStart, err := startStats(l, c, buildVersion, configTest)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
return nil, util.NewContextualError("Failed to start stats emitter", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if configTest {
|
if configTest {
|
||||||
|
|
|
@ -5,12 +5,12 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewPunchyFromConfig(t *testing.T) {
|
func TestNewPunchyFromConfig(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
|
|
||||||
// Test defaults
|
// Test defaults
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package util
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
|
@ -1,4 +1,4 @@
|
||||||
package util
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -7,7 +7,7 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewTestLogger() *logrus.Logger {
|
func NewLogger() *logrus.Logger {
|
||||||
l := logrus.New()
|
l := logrus.New()
|
||||||
|
|
||||||
v := os.Getenv("TEST_LOGS")
|
v := os.Getenv("TEST_LOGS")
|
|
@ -6,12 +6,12 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_parseRoutes(t *testing.T) {
|
func Test_parseRoutes(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ func Test_parseRoutes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_parseUnsafeRoutes(t *testing.T) {
|
func Test_parseUnsafeRoutes(t *testing.T) {
|
||||||
l := util.NewTestLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ContextualError struct {
|
||||||
|
RealError error
|
||||||
|
Fields map[string]interface{}
|
||||||
|
Context string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
|
||||||
|
return ContextualError{Context: msg, Fields: fields, RealError: realError}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ce ContextualError) Error() string {
|
||||||
|
if ce.RealError == nil {
|
||||||
|
return ce.Context
|
||||||
|
}
|
||||||
|
return ce.RealError.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ce ContextualError) Unwrap() error {
|
||||||
|
if ce.RealError == nil {
|
||||||
|
return errors.New(ce.Context)
|
||||||
|
}
|
||||||
|
return ce.RealError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ce *ContextualError) Log(lr *logrus.Logger) {
|
||||||
|
if ce.RealError != nil {
|
||||||
|
lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
|
||||||
|
} else {
|
||||||
|
lr.WithFields(ce.Fields).Error(ce.Context)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package nebula
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -8,6 +8,8 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type m map[string]interface{}
|
||||||
|
|
||||||
type TestLogWriter struct {
|
type TestLogWriter struct {
|
||||||
Logs []string
|
Logs []string
|
||||||
}
|
}
|
Loading…
Reference in New Issue