mirror of https://github.com/slackhq/nebula.git
Don't use a global logger (#423)
This commit is contained in:
parent
7a9f9dbded
commit
3ea7e1b75f
4
bits.go
4
bits.go
|
@ -26,7 +26,7 @@ func NewBits(bits uint64) *Bits {
|
|||
}
|
||||
}
|
||||
|
||||
func (b *Bits) Check(i uint64) bool {
|
||||
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
|
||||
// If i is the next number, return true.
|
||||
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
|
||||
return true
|
||||
|
@ -47,7 +47,7 @@ func (b *Bits) Check(i uint64) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (b *Bits) Update(i uint64) bool {
|
||||
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
|
||||
// If i is the next number, return true and update current.
|
||||
if i == b.current+1 {
|
||||
// Report missed packets, we can only understand what was missed after the first window has been gone through
|
||||
|
|
154
bits_test.go
154
bits_test.go
|
@ -7,6 +7,7 @@ import (
|
|||
)
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
b := NewBits(10)
|
||||
|
||||
// make sure it is the right size
|
||||
|
@ -14,46 +15,46 @@ func TestBits(t *testing.T) {
|
|||
|
||||
// This is initialized to zero - receive one. This should work.
|
||||
|
||||
assert.True(t, b.Check(1))
|
||||
u := b.Update(1)
|
||||
assert.True(t, b.Check(l, 1))
|
||||
u := b.Update(l, 1)
|
||||
assert.True(t, u)
|
||||
assert.EqualValues(t, 1, b.current)
|
||||
g := []bool{false, true, false, false, false, false, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Receive two
|
||||
assert.True(t, b.Check(2))
|
||||
u = b.Update(2)
|
||||
assert.True(t, b.Check(l, 2))
|
||||
u = b.Update(l, 2)
|
||||
assert.True(t, u)
|
||||
assert.EqualValues(t, 2, b.current)
|
||||
g = []bool{false, true, true, false, false, false, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Receive two again - it will fail
|
||||
assert.False(t, b.Check(2))
|
||||
u = b.Update(2)
|
||||
assert.False(t, b.Check(l, 2))
|
||||
u = b.Update(l, 2)
|
||||
assert.False(t, u)
|
||||
assert.EqualValues(t, 2, b.current)
|
||||
|
||||
// Jump ahead to 15, which should clear everything and set the 6th element
|
||||
assert.True(t, b.Check(15))
|
||||
u = b.Update(15)
|
||||
assert.True(t, b.Check(l, 15))
|
||||
u = b.Update(l, 15)
|
||||
assert.True(t, u)
|
||||
assert.EqualValues(t, 15, b.current)
|
||||
g = []bool{false, false, false, false, false, true, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Mark 14, which is allowed because it is in the window
|
||||
assert.True(t, b.Check(14))
|
||||
u = b.Update(14)
|
||||
assert.True(t, b.Check(l, 14))
|
||||
u = b.Update(l, 14)
|
||||
assert.True(t, u)
|
||||
assert.EqualValues(t, 15, b.current)
|
||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||
assert.Equal(t, g, b.bits)
|
||||
|
||||
// Mark 5, which is not allowed because it is not in the window
|
||||
assert.False(t, b.Check(5))
|
||||
u = b.Update(5)
|
||||
assert.False(t, b.Check(l, 5))
|
||||
u = b.Update(l, 5)
|
||||
assert.False(t, u)
|
||||
assert.EqualValues(t, 15, b.current)
|
||||
g = []bool{false, false, false, false, true, true, false, false, false, false}
|
||||
|
@ -61,63 +62,65 @@ func TestBits(t *testing.T) {
|
|||
|
||||
// make sure we handle wrapping around once to the current position
|
||||
b = NewBits(10)
|
||||
assert.True(t, b.Update(1))
|
||||
assert.True(t, b.Update(11))
|
||||
assert.True(t, b.Update(l, 1))
|
||||
assert.True(t, b.Update(l, 11))
|
||||
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)
|
||||
|
||||
// Walk through a few windows in order
|
||||
b = NewBits(10)
|
||||
for i := uint64(0); i <= 100; i++ {
|
||||
assert.True(t, b.Check(i), "Error while checking %v", i)
|
||||
assert.True(t, b.Update(i), "Error while updating %v", i)
|
||||
assert.True(t, b.Check(l, i), "Error while checking %v", i)
|
||||
assert.True(t, b.Update(l, i), "Error while updating %v", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitsDupeCounter(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
b := NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
b.dupeCounter.Clear()
|
||||
b.outOfWindowCounter.Clear()
|
||||
|
||||
assert.True(t, b.Update(1))
|
||||
assert.True(t, b.Update(l, 1))
|
||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||
|
||||
assert.False(t, b.Update(1))
|
||||
assert.False(t, b.Update(l, 1))
|
||||
assert.Equal(t, int64(1), b.dupeCounter.Count())
|
||||
|
||||
assert.True(t, b.Update(2))
|
||||
assert.True(t, b.Update(l, 2))
|
||||
assert.Equal(t, int64(1), b.dupeCounter.Count())
|
||||
|
||||
assert.True(t, b.Update(3))
|
||||
assert.True(t, b.Update(l, 3))
|
||||
assert.Equal(t, int64(1), b.dupeCounter.Count())
|
||||
|
||||
assert.False(t, b.Update(1))
|
||||
assert.False(t, b.Update(l, 1))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.Equal(t, int64(2), b.dupeCounter.Count())
|
||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||
}
|
||||
|
||||
func TestBitsOutOfWindowCounter(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
b := NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
b.dupeCounter.Clear()
|
||||
b.outOfWindowCounter.Clear()
|
||||
|
||||
assert.True(t, b.Update(20))
|
||||
assert.True(t, b.Update(l, 20))
|
||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||
|
||||
assert.True(t, b.Update(21))
|
||||
assert.True(t, b.Update(22))
|
||||
assert.True(t, b.Update(23))
|
||||
assert.True(t, b.Update(24))
|
||||
assert.True(t, b.Update(25))
|
||||
assert.True(t, b.Update(26))
|
||||
assert.True(t, b.Update(27))
|
||||
assert.True(t, b.Update(28))
|
||||
assert.True(t, b.Update(29))
|
||||
assert.True(t, b.Update(l, 21))
|
||||
assert.True(t, b.Update(l, 22))
|
||||
assert.True(t, b.Update(l, 23))
|
||||
assert.True(t, b.Update(l, 24))
|
||||
assert.True(t, b.Update(l, 25))
|
||||
assert.True(t, b.Update(l, 26))
|
||||
assert.True(t, b.Update(l, 27))
|
||||
assert.True(t, b.Update(l, 28))
|
||||
assert.True(t, b.Update(l, 29))
|
||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||
|
||||
assert.False(t, b.Update(0))
|
||||
assert.False(t, b.Update(l, 0))
|
||||
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
|
||||
|
||||
//tODO: make sure lostcounter doesn't increase in orderly increment
|
||||
|
@ -127,23 +130,24 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBitsLostCounter(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
b := NewBits(10)
|
||||
b.lostCounter.Clear()
|
||||
b.dupeCounter.Clear()
|
||||
b.outOfWindowCounter.Clear()
|
||||
|
||||
//assert.True(t, b.Update(0))
|
||||
assert.True(t, b.Update(0))
|
||||
assert.True(t, b.Update(20))
|
||||
assert.True(t, b.Update(21))
|
||||
assert.True(t, b.Update(22))
|
||||
assert.True(t, b.Update(23))
|
||||
assert.True(t, b.Update(24))
|
||||
assert.True(t, b.Update(25))
|
||||
assert.True(t, b.Update(26))
|
||||
assert.True(t, b.Update(27))
|
||||
assert.True(t, b.Update(28))
|
||||
assert.True(t, b.Update(29))
|
||||
assert.True(t, b.Update(l, 0))
|
||||
assert.True(t, b.Update(l, 20))
|
||||
assert.True(t, b.Update(l, 21))
|
||||
assert.True(t, b.Update(l, 22))
|
||||
assert.True(t, b.Update(l, 23))
|
||||
assert.True(t, b.Update(l, 24))
|
||||
assert.True(t, b.Update(l, 25))
|
||||
assert.True(t, b.Update(l, 26))
|
||||
assert.True(t, b.Update(l, 27))
|
||||
assert.True(t, b.Update(l, 28))
|
||||
assert.True(t, b.Update(l, 29))
|
||||
assert.Equal(t, int64(20), b.lostCounter.Count())
|
||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
|
||||
|
@ -153,56 +157,56 @@ func TestBitsLostCounter(t *testing.T) {
|
|||
b.dupeCounter.Clear()
|
||||
b.outOfWindowCounter.Clear()
|
||||
|
||||
assert.True(t, b.Update(0))
|
||||
assert.True(t, b.Update(l, 0))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
assert.True(t, b.Update(9))
|
||||
assert.True(t, b.Update(l, 9))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
// 10 will set 0 index, 0 was already set, no lost packets
|
||||
assert.True(t, b.Update(10))
|
||||
assert.True(t, b.Update(l, 10))
|
||||
assert.Equal(t, int64(0), b.lostCounter.Count())
|
||||
// 11 will set 1 index, 1 was missed, we should see 1 packet lost
|
||||
assert.True(t, b.Update(11))
|
||||
assert.True(t, b.Update(l, 11))
|
||||
assert.Equal(t, int64(1), b.lostCounter.Count())
|
||||
// Now let's fill in the window, should end up with 8 lost packets
|
||||
assert.True(t, b.Update(12))
|
||||
assert.True(t, b.Update(13))
|
||||
assert.True(t, b.Update(14))
|
||||
assert.True(t, b.Update(15))
|
||||
assert.True(t, b.Update(16))
|
||||
assert.True(t, b.Update(17))
|
||||
assert.True(t, b.Update(18))
|
||||
assert.True(t, b.Update(19))
|
||||
assert.True(t, b.Update(l, 12))
|
||||
assert.True(t, b.Update(l, 13))
|
||||
assert.True(t, b.Update(l, 14))
|
||||
assert.True(t, b.Update(l, 15))
|
||||
assert.True(t, b.Update(l, 16))
|
||||
assert.True(t, b.Update(l, 17))
|
||||
assert.True(t, b.Update(l, 18))
|
||||
assert.True(t, b.Update(l, 19))
|
||||
assert.Equal(t, int64(8), b.lostCounter.Count())
|
||||
|
||||
// Jump ahead by a window size
|
||||
assert.True(t, b.Update(29))
|
||||
assert.True(t, b.Update(l, 29))
|
||||
assert.Equal(t, int64(8), b.lostCounter.Count())
|
||||
// Now lets walk ahead normally through the window, the missed packets should fill in
|
||||
assert.True(t, b.Update(30))
|
||||
assert.True(t, b.Update(31))
|
||||
assert.True(t, b.Update(32))
|
||||
assert.True(t, b.Update(33))
|
||||
assert.True(t, b.Update(34))
|
||||
assert.True(t, b.Update(35))
|
||||
assert.True(t, b.Update(36))
|
||||
assert.True(t, b.Update(37))
|
||||
assert.True(t, b.Update(38))
|
||||
assert.True(t, b.Update(l, 30))
|
||||
assert.True(t, b.Update(l, 31))
|
||||
assert.True(t, b.Update(l, 32))
|
||||
assert.True(t, b.Update(l, 33))
|
||||
assert.True(t, b.Update(l, 34))
|
||||
assert.True(t, b.Update(l, 35))
|
||||
assert.True(t, b.Update(l, 36))
|
||||
assert.True(t, b.Update(l, 37))
|
||||
assert.True(t, b.Update(l, 38))
|
||||
// 39 packets tracked, 22 seen, 17 lost
|
||||
assert.Equal(t, int64(17), b.lostCounter.Count())
|
||||
|
||||
// Jump ahead by 2 windows, should have recording 1 full window missing
|
||||
assert.True(t, b.Update(58))
|
||||
assert.True(t, b.Update(l, 58))
|
||||
assert.Equal(t, int64(27), b.lostCounter.Count())
|
||||
// Now lets walk ahead normally through the window, the missed packets should fill in from this window
|
||||
assert.True(t, b.Update(59))
|
||||
assert.True(t, b.Update(60))
|
||||
assert.True(t, b.Update(61))
|
||||
assert.True(t, b.Update(62))
|
||||
assert.True(t, b.Update(63))
|
||||
assert.True(t, b.Update(64))
|
||||
assert.True(t, b.Update(65))
|
||||
assert.True(t, b.Update(66))
|
||||
assert.True(t, b.Update(67))
|
||||
assert.True(t, b.Update(l, 59))
|
||||
assert.True(t, b.Update(l, 60))
|
||||
assert.True(t, b.Update(l, 61))
|
||||
assert.True(t, b.Update(l, 62))
|
||||
assert.True(t, b.Update(l, 63))
|
||||
assert.True(t, b.Update(l, 64))
|
||||
assert.True(t, b.Update(l, 65))
|
||||
assert.True(t, b.Update(l, 66))
|
||||
assert.True(t, b.Update(l, 67))
|
||||
// 68 packets tracked, 32 seen, 36 missed
|
||||
assert.Equal(t, int64(36), b.lostCounter.Count())
|
||||
assert.Equal(t, int64(0), b.dupeCounter.Count())
|
||||
|
|
3
cert.go
3
cert.go
|
@ -7,6 +7,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
|
@ -119,7 +120,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) {
|
|||
return NewCertState(nebulaCert, rawKey)
|
||||
}
|
||||
|
||||
func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) {
|
||||
func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) {
|
||||
var rawCA []byte
|
||||
var err error
|
||||
|
||||
|
|
|
@ -46,15 +46,16 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
config := nebula.NewConfig()
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
|
||||
config := nebula.NewConfig(l)
|
||||
err := config.Load(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to load config: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
||||
|
||||
switch v := err.(type) {
|
||||
|
|
|
@ -24,14 +24,15 @@ func (p *program) Start(s service.Service) error {
|
|||
// Start should not block.
|
||||
logger.Info("Nebula service starting.")
|
||||
|
||||
config := nebula.NewConfig()
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
|
||||
config := nebula.NewConfig(l)
|
||||
err := config.Load(*p.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load config: %s", err)
|
||||
}
|
||||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -40,15 +40,16 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
config := nebula.NewConfig()
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
|
||||
config := nebula.NewConfig(l)
|
||||
err := config.Load(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to load config: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
l := logrus.New()
|
||||
l.Out = os.Stdout
|
||||
c, err := nebula.Main(config, *configTest, Build, l, nil)
|
||||
|
||||
switch v := err.(type) {
|
||||
|
|
18
config.go
18
config.go
|
@ -26,11 +26,13 @@ type Config struct {
|
|||
Settings map[interface{}]interface{}
|
||||
oldSettings map[interface{}]interface{}
|
||||
callbacks []func(*Config)
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func NewConfig() *Config {
|
||||
func NewConfig(l *logrus.Logger) *Config {
|
||||
return &Config{
|
||||
Settings: make(map[interface{}]interface{}),
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -99,12 +101,12 @@ func (c *Config) HasChanged(k string) bool {
|
|||
|
||||
newVals, err := yaml.Marshal(nv)
|
||||
if err != nil {
|
||||
l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
|
||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
|
||||
}
|
||||
|
||||
oldVals, err := yaml.Marshal(ov)
|
||||
if err != nil {
|
||||
l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
|
||||
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
|
||||
}
|
||||
|
||||
return string(newVals) != string(oldVals)
|
||||
|
@ -118,7 +120,7 @@ func (c *Config) CatchHUP() {
|
|||
|
||||
go func() {
|
||||
for range ch {
|
||||
l.Info("Caught HUP, reloading config")
|
||||
c.l.Info("Caught HUP, reloading config")
|
||||
c.ReloadConfig()
|
||||
}
|
||||
}()
|
||||
|
@ -132,7 +134,7 @@ func (c *Config) ReloadConfig() {
|
|||
|
||||
err := c.Load(c.path)
|
||||
if err != nil {
|
||||
l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
|
||||
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -500,7 +502,7 @@ func configLogger(c *Config) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
|
||||
}
|
||||
l.SetLevel(logLevel)
|
||||
c.l.SetLevel(logLevel)
|
||||
|
||||
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
|
||||
timestampFormat := c.GetString("logging.timestamp_format", "")
|
||||
|
@ -512,13 +514,13 @@ func configLogger(c *Config) error {
|
|||
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
|
||||
switch logFormat {
|
||||
case "text":
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
c.l.Formatter = &logrus.TextFormatter{
|
||||
TimestampFormat: timestampFormat,
|
||||
FullTimestamp: fullTimestamp,
|
||||
DisableTimestamp: disableTimestamp,
|
||||
}
|
||||
case "json":
|
||||
l.Formatter = &logrus.JSONFormatter{
|
||||
c.l.Formatter = &logrus.JSONFormatter{
|
||||
TimestampFormat: timestampFormat,
|
||||
DisableTimestamp: disableTimestamp,
|
||||
}
|
||||
|
|
|
@ -11,14 +11,15 @@ import (
|
|||
)
|
||||
|
||||
func TestConfig_Load(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
dir, err := ioutil.TempDir("", "config-test")
|
||||
// invalid yaml
|
||||
c := NewConfig()
|
||||
c := NewConfig(l)
|
||||
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
|
||||
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
|
||||
|
||||
// simple multi config merge
|
||||
c = NewConfig()
|
||||
c = NewConfig(l)
|
||||
os.RemoveAll(dir)
|
||||
os.Mkdir(dir, 0755)
|
||||
|
||||
|
@ -40,8 +41,9 @@ func TestConfig_Load(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestConfig_Get(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
// test simple type
|
||||
c := NewConfig()
|
||||
c := NewConfig(l)
|
||||
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
|
||||
assert.Equal(t, "hi", c.Get("firewall.outbound"))
|
||||
|
||||
|
@ -55,13 +57,15 @@ func TestConfig_Get(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestConfig_GetStringSlice(t *testing.T) {
|
||||
c := NewConfig()
|
||||
l := NewTestLogger()
|
||||
c := NewConfig(l)
|
||||
c.Settings["slice"] = []interface{}{"one", "two"}
|
||||
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
|
||||
}
|
||||
|
||||
func TestConfig_GetBool(t *testing.T) {
|
||||
c := NewConfig()
|
||||
l := NewTestLogger()
|
||||
c := NewConfig(l)
|
||||
c.Settings["bool"] = true
|
||||
assert.Equal(t, true, c.GetBool("bool", false))
|
||||
|
||||
|
@ -88,7 +92,8 @@ func TestConfig_GetBool(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestConfig_GetAllowList(t *testing.T) {
|
||||
c := NewConfig()
|
||||
l := NewTestLogger()
|
||||
c := NewConfig(l)
|
||||
c.Settings["allowlist"] = map[interface{}]interface{}{
|
||||
"192.168.0.0": true,
|
||||
}
|
||||
|
@ -181,20 +186,21 @@ func TestConfig_GetAllowList(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestConfig_HasChanged(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
// No reload has occurred, return false
|
||||
c := NewConfig()
|
||||
c := NewConfig(l)
|
||||
c.Settings["test"] = "hi"
|
||||
assert.False(t, c.HasChanged(""))
|
||||
|
||||
// Test key change
|
||||
c = NewConfig()
|
||||
c = NewConfig(l)
|
||||
c.Settings["test"] = "hi"
|
||||
c.oldSettings = map[interface{}]interface{}{"test": "no"}
|
||||
assert.True(t, c.HasChanged("test"))
|
||||
assert.True(t, c.HasChanged(""))
|
||||
|
||||
// No key change
|
||||
c = NewConfig()
|
||||
c = NewConfig(l)
|
||||
c.Settings["test"] = "hi"
|
||||
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
|
||||
assert.False(t, c.HasChanged("test"))
|
||||
|
@ -202,12 +208,13 @@ func TestConfig_HasChanged(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestConfig_ReloadConfig(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
done := make(chan bool, 1)
|
||||
dir, err := ioutil.TempDir("", "config-test")
|
||||
assert.Nil(t, err)
|
||||
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
|
||||
|
||||
c := NewConfig()
|
||||
c := NewConfig(l)
|
||||
assert.Nil(t, c.Load(dir))
|
||||
|
||||
assert.False(t, c.HasChanged("outer.inner"))
|
||||
|
|
|
@ -28,10 +28,11 @@ type connectionManager struct {
|
|||
checkInterval int
|
||||
pendingDeletionInterval int
|
||||
|
||||
l *logrus.Logger
|
||||
// I wanted to call one matLock
|
||||
}
|
||||
|
||||
func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
|
||||
func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
|
||||
nc := &connectionManager{
|
||||
hostMap: intf.hostMap,
|
||||
in: make(map[uint32]struct{}),
|
||||
|
@ -47,6 +48,7 @@ func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterva
|
|||
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
|
||||
checkInterval: checkInterval,
|
||||
pendingDeletionInterval: pendingDeletionInterval,
|
||||
l: l,
|
||||
}
|
||||
nc.Start()
|
||||
return nc
|
||||
|
@ -166,8 +168,8 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
|||
|
||||
// If we saw incoming packets from this ip, just return
|
||||
if traf {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).
|
||||
if n.l.Level >= logrus.DebugLevel {
|
||||
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
||||
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
|
||||
Debug("Tunnel status")
|
||||
}
|
||||
|
@ -179,13 +181,13 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
|||
// If we didn't we may need to probe or destroy the conn
|
||||
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
|
||||
if err != nil {
|
||||
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||
n.ClearIP(vpnIP)
|
||||
n.ClearPendingDeletion(vpnIP)
|
||||
continue
|
||||
}
|
||||
|
||||
hostinfo.logger().
|
||||
hostinfo.logger(n.l).
|
||||
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
|
||||
Debug("Tunnel status")
|
||||
|
||||
|
@ -194,7 +196,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
|||
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
|
||||
|
||||
} else {
|
||||
hostinfo.logger().Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
|
||||
hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
|
||||
}
|
||||
n.AddPendingDeletion(vpnIP)
|
||||
}
|
||||
|
@ -214,7 +216,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
|||
// If we saw incoming packets from this ip, just return
|
||||
traf := n.CheckIn(vpnIP)
|
||||
if traf {
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).
|
||||
n.l.WithField("vpnIp", IntIp(vpnIP)).
|
||||
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
|
||||
Debug("Tunnel status")
|
||||
n.ClearIP(vpnIP)
|
||||
|
@ -226,7 +228,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
|||
if err != nil {
|
||||
n.ClearIP(vpnIP)
|
||||
n.ClearPendingDeletion(vpnIP)
|
||||
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -236,7 +238,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
|||
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
|
||||
cn = hostinfo.ConnectionState.peerCert.Details.Name
|
||||
}
|
||||
hostinfo.logger().
|
||||
hostinfo.logger(n.l).
|
||||
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
|
||||
WithField("certName", cn).
|
||||
Info("Tunnel status")
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
var vpnIP uint32
|
||||
|
||||
func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
|
@ -20,7 +21,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
preferredRanges := []*net.IPNet{localrange}
|
||||
|
||||
// Very incomplete mock objects
|
||||
hostMap := NewHostMap("test", vpncidr, preferredRanges)
|
||||
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||
cs := &CertState{
|
||||
rawCertificate: []byte{},
|
||||
privateKey: []byte{},
|
||||
|
@ -28,7 +29,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
rawCertificateNoKey: []byte{},
|
||||
}
|
||||
|
||||
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||
lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||
ifce := &Interface{
|
||||
hostMap: hostMap,
|
||||
inside: &Tun{},
|
||||
|
@ -36,12 +37,13 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
certState: cs,
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
||||
l: l,
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
// Create manager
|
||||
nc := newConnectionManager(ifce, 5, 10)
|
||||
nc := newConnectionManager(l, ifce, 5, 10)
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
|
@ -79,13 +81,14 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
|
||||
// Very incomplete mock objects
|
||||
hostMap := NewHostMap("test", vpncidr, preferredRanges)
|
||||
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||
cs := &CertState{
|
||||
rawCertificate: []byte{},
|
||||
privateKey: []byte{},
|
||||
|
@ -93,7 +96,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||
rawCertificateNoKey: []byte{},
|
||||
}
|
||||
|
||||
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||
lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
|
||||
ifce := &Interface{
|
||||
hostMap: hostMap,
|
||||
inside: &Tun{},
|
||||
|
@ -101,12 +104,13 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||
certState: cs,
|
||||
firewall: &Firewall{},
|
||||
lightHouse: lh,
|
||||
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
||||
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
|
||||
l: l,
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
// Create manager
|
||||
nc := newConnectionManager(ifce, 5, 10)
|
||||
nc := newConnectionManager(l, ifce, 5, 10)
|
||||
p := []byte("")
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"sync/atomic"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
|
@ -26,7 +27,7 @@ type ConnectionState struct {
|
|||
ready bool
|
||||
}
|
||||
|
||||
func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||
cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
|
||||
if f.cipher == "chachapoly" {
|
||||
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
|
@ -37,7 +38,7 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa
|
|||
|
||||
b := NewBits(ReplayWindow)
|
||||
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
|
||||
b.Update(0)
|
||||
b.Update(l, 0)
|
||||
|
||||
hs, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: cs,
|
||||
|
|
|
@ -13,9 +13,10 @@ import (
|
|||
)
|
||||
|
||||
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||
// To properly ensure we are not exposing core memory to the caller
|
||||
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
|
||||
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
|
||||
remote1 := NewUDPAddr(int2ip(100), 4444)
|
||||
remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
|
||||
ipNet := net.IPNet{
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// This whole thing should be rewritten to use context
|
||||
|
@ -63,7 +64,7 @@ func (d *dnsRecords) Add(host, data string) {
|
|||
d.Unlock()
|
||||
}
|
||||
|
||||
func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
|
||||
for _, q := range m.Question {
|
||||
switch q.Qtype {
|
||||
case dns.TypeA:
|
||||
|
@ -95,34 +96,38 @@ func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|||
}
|
||||
}
|
||||
|
||||
func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Compress = false
|
||||
|
||||
switch r.Opcode {
|
||||
case dns.OpcodeQuery:
|
||||
parseQuery(m, w)
|
||||
parseQuery(l, m, w)
|
||||
}
|
||||
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func dnsMain(hostMap *HostMap, c *Config) {
|
||||
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) {
|
||||
dnsR = newDnsRecords(hostMap)
|
||||
|
||||
// attach request handler func
|
||||
dns.HandleFunc(".", handleDnsRequest)
|
||||
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
handleDnsRequest(l, w, r)
|
||||
})
|
||||
|
||||
c.RegisterReloadCallback(reloadDns)
|
||||
startDns(c)
|
||||
c.RegisterReloadCallback(func(c *Config) {
|
||||
reloadDns(l, c)
|
||||
})
|
||||
startDns(l, c)
|
||||
}
|
||||
|
||||
func getDnsServerAddr(c *Config) string {
|
||||
return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
|
||||
}
|
||||
|
||||
func startDns(c *Config) {
|
||||
func startDns(l *logrus.Logger, c *Config) {
|
||||
dnsAddr = getDnsServerAddr(c)
|
||||
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
||||
l.Debugf("Starting DNS responder at %s\n", dnsAddr)
|
||||
|
@ -133,7 +138,7 @@ func startDns(c *Config) {
|
|||
}
|
||||
}
|
||||
|
||||
func reloadDns(c *Config) {
|
||||
func reloadDns(l *logrus.Logger, c *Config) {
|
||||
if dnsAddr == getDnsServerAddr(c) {
|
||||
l.Debug("No DNS server config change detected")
|
||||
return
|
||||
|
@ -141,5 +146,5 @@ func reloadDns(c *Config) {
|
|||
|
||||
l.Debug("Restarting DNS server")
|
||||
dnsServer.Shutdown()
|
||||
go startDns(c)
|
||||
go startDns(l, c)
|
||||
}
|
||||
|
|
31
firewall.go
31
firewall.go
|
@ -70,6 +70,7 @@ type Firewall struct {
|
|||
|
||||
trackTCPRTT bool
|
||||
metricTCPRTT metrics.Histogram
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type FirewallConntrack struct {
|
||||
|
@ -156,7 +157,7 @@ func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||
func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
||||
//TODO: error on 0 duration
|
||||
var min, max time.Duration
|
||||
|
||||
|
@ -195,11 +196,13 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
|
|||
DefaultTimeout: defaultTimeout,
|
||||
localIps: localIps,
|
||||
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
|
||||
func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
|
||||
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
|
||||
fw := NewFirewall(
|
||||
l,
|
||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
||||
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
||||
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
||||
|
@ -207,12 +210,12 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
|
|||
//TODO: max_connections
|
||||
)
|
||||
|
||||
err := AddFirewallRulesFromConfig(false, c, fw)
|
||||
err := AddFirewallRulesFromConfig(l, false, c, fw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = AddFirewallRulesFromConfig(true, c, fw)
|
||||
err = AddFirewallRulesFromConfig(l, true, c, fw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -240,7 +243,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||
if !incoming {
|
||||
direction = "outgoing"
|
||||
}
|
||||
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
|
||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
|
||||
Info("Firewall rule added")
|
||||
|
||||
var (
|
||||
|
@ -276,7 +279,7 @@ func (f *Firewall) GetRuleHash() string {
|
|||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterface) error {
|
||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
|
||||
var table string
|
||||
if inbound {
|
||||
table = "firewall.inbound"
|
||||
|
@ -296,7 +299,7 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
|
|||
|
||||
for i, t := range rs {
|
||||
var groups []string
|
||||
r, err := convertRule(t, table, i)
|
||||
r, err := convertRule(l, t, table, i)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s rule #%v; %s", table, i, err)
|
||||
}
|
||||
|
@ -459,8 +462,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
|
|||
|
||||
// We now know which firewall table to check against
|
||||
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
h.logger().
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
h.logger(f.l).
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
|
@ -472,8 +475,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
|
|||
return false
|
||||
}
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
h.logger().
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
h.logger(f.l).
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
|
@ -795,7 +798,7 @@ type rule struct {
|
|||
CASha string
|
||||
}
|
||||
|
||||
func convertRule(p interface{}, table string, i int) (rule, error) {
|
||||
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
|
||||
r := rule{}
|
||||
|
||||
m, ok := p.(map[interface{}]interface{})
|
||||
|
@ -968,14 +971,14 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
|||
|
||||
// Get checks if the cache ticker has moved to the next version before returning
|
||||
// the map. If it has moved, we reset the map.
|
||||
func (c *ConntrackCacheTicker) Get() ConntrackCache {
|
||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
|
||||
c.cacheV = tick
|
||||
if ll := len(c.cache); ll > 0 {
|
||||
if l.GetLevel() == logrus.DebugLevel {
|
||||
if l.Level == logrus.DebugLevel {
|
||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||
}
|
||||
c.cache = make(ConntrackCache, ll)
|
||||
|
|
149
firewall_test.go
149
firewall_test.go
|
@ -15,8 +15,9 @@ import (
|
|||
)
|
||||
|
||||
func TestNewFirewall(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
c := &cert.NebulaCertificate{}
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
conntrack := fw.Conntrack
|
||||
assert.NotNil(t, conntrack)
|
||||
assert.NotNil(t, conntrack.Conns)
|
||||
|
@ -31,35 +32,34 @@ func TestNewFirewall(t *testing.T) {
|
|||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
|
||||
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
|
||||
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
|
||||
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
|
||||
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
||||
}
|
||||
|
||||
func TestFirewall_AddRule(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
out := l.Out
|
||||
l.SetOutput(ob)
|
||||
defer l.SetOutput(out)
|
||||
|
||||
c := &cert.NebulaCertificate{}
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.NotNil(t, fw.InRules)
|
||||
assert.NotNil(t, fw.OutRules)
|
||||
|
||||
|
@ -74,7 +74,7 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
|
||||
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
|
||||
assert.False(t, fw.InRules.UDP[1].Any.Any)
|
||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
|
||||
|
@ -83,7 +83,7 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
|
||||
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
|
||||
assert.False(t, fw.InRules.ICMP[1].Any.Any)
|
||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||
|
@ -92,23 +92,23 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
|
||||
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
|
||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
||||
|
||||
// Set any and clear fields
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
|
||||
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
||||
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
||||
|
@ -125,26 +125,25 @@ func TestFirewall_AddRule(t *testing.T) {
|
|||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
|
||||
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||
|
||||
// Test error conditions
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
|
||||
assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
out := l.Out
|
||||
l.SetOutput(ob)
|
||||
defer l.SetOutput(out)
|
||||
|
||||
p := FirewallPacket{
|
||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||
|
@ -177,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) {
|
|||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
|
@ -196,27 +195,27 @@ func TestFirewall_Drop(t *testing.T) {
|
|||
p.RemoteIP = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
|
||||
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
|
||||
|
@ -317,10 +316,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||
}
|
||||
|
||||
func TestFirewall_Drop2(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
out := l.Out
|
||||
l.SetOutput(ob)
|
||||
defer l.SetOutput(out)
|
||||
|
||||
p := FirewallPacket{
|
||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||
|
@ -365,7 +363,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||
}
|
||||
h1.CreateRemoteCIDR(&c1)
|
||||
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
|
@ -377,10 +375,9 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFirewall_Drop3(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
out := l.Out
|
||||
l.SetOutput(ob)
|
||||
defer l.SetOutput(out)
|
||||
|
||||
p := FirewallPacket{
|
||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||
|
@ -448,7 +445,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||
}
|
||||
h3.CreateRemoteCIDR(&c3)
|
||||
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
|
||||
cp := cert.NewCAPool()
|
||||
|
@ -464,10 +461,9 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
out := l.Out
|
||||
l.SetOutput(ob)
|
||||
defer l.SetOutput(out)
|
||||
|
||||
p := FirewallPacket{
|
||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||
|
@ -500,7 +496,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
}
|
||||
h.CreateRemoteCIDR(&c)
|
||||
|
||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
|
@ -513,7 +509,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
|
||||
|
||||
oldFw := fw
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
@ -522,7 +518,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
|
||||
|
||||
oldFw = fw
|
||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
@ -647,124 +643,126 @@ func Test_parsePort(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNewFirewallFromConfig(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
// Test a bad rule definition
|
||||
c := &cert.NebulaCertificate{}
|
||||
conf := NewConfig()
|
||||
conf := NewConfig(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
|
||||
_, err := NewFirewallFromConfig(c, conf)
|
||||
_, err := NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||
|
||||
// Test both port and code
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
|
||||
_, err = NewFirewallFromConfig(c, conf)
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||
|
||||
// Test missing host, group, cidr, ca_name and ca_sha
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
|
||||
_, err = NewFirewallFromConfig(c, conf)
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
|
||||
|
||||
// Test code/port error
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
|
||||
_, err = NewFirewallFromConfig(c, conf)
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
|
||||
_, err = NewFirewallFromConfig(c, conf)
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||
|
||||
// Test proto error
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
|
||||
_, err = NewFirewallFromConfig(c, conf)
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||
|
||||
// Test cidr parse error
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(c, conf)
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
|
||||
|
||||
// Test both group and groups
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||
_, err = NewFirewallFromConfig(c, conf)
|
||||
_, err = NewFirewallFromConfig(l, c, conf)
|
||||
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||
}
|
||||
|
||||
func TestAddFirewallRulesFromConfig(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
// Test adding tcp rule
|
||||
conf := NewConfig()
|
||||
conf := NewConfig(l)
|
||||
mf := &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||
|
||||
// Test adding udp rule
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||
|
||||
// Test adding icmp rule
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||
|
||||
// Test adding any rule
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_sha
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
|
||||
|
||||
// Test adding rule with ca_name
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
|
||||
|
||||
// Test single group
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
||||
|
||||
// Test single groups
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
|
||||
|
||||
// Test multiple AND groups
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
mf = &mockFirewall{}
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
|
||||
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
|
||||
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
|
||||
|
||||
// Test Add error
|
||||
conf = NewConfig()
|
||||
conf = NewConfig(l)
|
||||
mf = &mockFirewall{}
|
||||
mf.nextCallReturn = errors.New("test error")
|
||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
|
||||
assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`")
|
||||
assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
|
||||
}
|
||||
|
||||
func TestTCPRTTTracking(t *testing.T) {
|
||||
|
@ -859,17 +857,16 @@ func TestTCPRTTTracking(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFirewall_convertRule(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
out := l.Out
|
||||
l.SetOutput(ob)
|
||||
defer l.SetOutput(out)
|
||||
|
||||
// Ensure group array of 1 is converted and a warning is printed
|
||||
c := map[interface{}]interface{}{
|
||||
"group": []interface{}{"group1"},
|
||||
}
|
||||
|
||||
r, err := convertRule(c, "test", 1)
|
||||
r, err := convertRule(l, c, "test", 1)
|
||||
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "group1", r.Group)
|
||||
|
@ -880,7 +877,7 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||
"group": []interface{}{"group1", "group2"},
|
||||
}
|
||||
|
||||
r, err = convertRule(c, "test", 1)
|
||||
r, err = convertRule(l, c, "test", 1)
|
||||
assert.Equal(t, "", ob.String())
|
||||
assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
|
||||
|
||||
|
@ -890,7 +887,7 @@ func TestFirewall_convertRule(t *testing.T) {
|
|||
"group": "group1",
|
||||
}
|
||||
|
||||
r, err = convertRule(c, "test", 1)
|
||||
r, err = convertRule(l, c, "test", 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "group1", r.Group)
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ const (
|
|||
|
||||
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
|
||||
if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
|
||||
l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
|||
|
||||
err := f.handshakeManager.AddIndexHostInfo(hostinfo)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
||||
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||
return
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
|||
hsBytes, err = proto.Marshal(hs)
|
||||
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
||||
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||
return
|
||||
}
|
||||
|
@ -58,14 +58,14 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
|||
|
||||
msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
||||
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||
return
|
||||
}
|
||||
|
||||
// We are sending handshake packet 1, so we don't expect to receive
|
||||
// handshake packet 1 from the responder
|
||||
ci.window.Update(1)
|
||||
ci.window.Update(f.l, 1)
|
||||
|
||||
hostinfo.HandshakePacket[0] = msg
|
||||
hostinfo.HandshakeReady = true
|
||||
|
@ -74,13 +74,13 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
|||
}
|
||||
|
||||
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
||||
ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0)
|
||||
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(1)
|
||||
ci.window.Update(f.l, 1)
|
||||
|
||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
|
||||
return
|
||||
}
|
||||
|
@ -91,14 +91,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
|
||||
*/
|
||||
if err != nil || hs.Details == nil {
|
||||
l.WithError(err).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||
return
|
||||
}
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||
Info("Invalid certificate from host")
|
||||
return
|
||||
|
@ -108,16 +108,16 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
fingerprint, _ := remoteCert.Sha256Sum()
|
||||
|
||||
if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||
return
|
||||
}
|
||||
|
||||
myIndex, err := generateIndex()
|
||||
myIndex, err := generateIndex(f.l)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||
|
@ -133,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
}
|
||||
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
|
@ -145,7 +145,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
|
||||
hsBytes, err := proto.Marshal(hs)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||
|
@ -155,13 +155,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||
return
|
||||
} else if dKey == nil || eKey == nil {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
|
||||
|
@ -178,7 +178,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
|
||||
// We are sending handshake packet 2, so we don't expect to receive
|
||||
// handshake packet 2 from the initiator.
|
||||
ci.window.Update(2)
|
||||
ci.window.Update(f.l, 2)
|
||||
|
||||
ci.peerCert = remoteCert
|
||||
ci.dKey = NewNebulaCipherState(dKey)
|
||||
|
@ -203,11 +203,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
||||
err := f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
WithError(err).Error("Failed to send handshake message")
|
||||
} else {
|
||||
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
Info("Handshake message sent")
|
||||
}
|
||||
|
@ -215,7 +215,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
case ErrExistingHostInfo:
|
||||
// This means there was an existing tunnel and we didn't win
|
||||
// handshake avoidance
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
|
@ -227,7 +227,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
return
|
||||
case ErrLocalIndexCollision:
|
||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
|
@ -238,7 +238,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
default:
|
||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
||||
// And we forget to update it here
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
|
@ -252,14 +252,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
||||
err = f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithError(err).Error("Failed to send handshake")
|
||||
} else {
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
|
@ -267,7 +267,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
|
|||
Info("Handshake message sent")
|
||||
}
|
||||
|
||||
hostinfo.handshakeComplete()
|
||||
hostinfo.handshakeComplete(f.l)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -280,7 +280,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||
defer hostinfo.Unlock()
|
||||
|
||||
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
||||
Info("Already seen this handshake packet")
|
||||
return false
|
||||
|
@ -288,14 +288,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||
|
||||
ci := hostinfo.ConnectionState
|
||||
// Mark packet 2 as seen so it doesn't show up as missed
|
||||
ci.window.Update(2)
|
||||
ci.window.Update(f.l, 2)
|
||||
|
||||
hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:]))
|
||||
copy(hostinfo.HandshakePacket[2], packet[HeaderLen:])
|
||||
|
||||
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
||||
Error("Failed to call noise.ReadMessage")
|
||||
|
||||
|
@ -304,7 +304,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||
// near future
|
||||
return false
|
||||
} else if dKey == nil || eKey == nil {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Error("Noise did not arrive at a key")
|
||||
return true
|
||||
|
@ -313,14 +313,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||
hs := &NebulaHandshake{}
|
||||
err = proto.Unmarshal(msg, hs)
|
||||
if err != nil || hs.Details == nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||
return true
|
||||
}
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
Error("Invalid certificate from host")
|
||||
return true
|
||||
|
@ -330,7 +330,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||
fingerprint, _ := remoteCert.Sha256Sum()
|
||||
|
||||
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
|
@ -362,7 +362,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||
hostinfo.CreateRemoteCIDR(remoteCert)
|
||||
|
||||
f.handshakeManager.Complete(hostinfo, f)
|
||||
hostinfo.handshakeComplete()
|
||||
hostinfo.handshakeComplete(f.l)
|
||||
f.metricHandshakes.Update(duration)
|
||||
|
||||
return false
|
||||
|
|
|
@ -53,11 +53,12 @@ type HandshakeManager struct {
|
|||
InboundHandshakeTimer *SystemTimerWheel
|
||||
|
||||
messageMetrics *MessageMetrics
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
|
||||
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
|
||||
return &HandshakeManager{
|
||||
pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
|
||||
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
|
||||
mainHostMap: mainHostMap,
|
||||
lightHouse: lightHouse,
|
||||
outside: outside,
|
||||
|
@ -70,6 +71,7 @@ func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainH
|
|||
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
|
||||
|
||||
messageMetrics: config.messageMetrics,
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,7 +80,7 @@ func (c *HandshakeManager) Run(f EncWriter) {
|
|||
for {
|
||||
select {
|
||||
case vpnIP := <-c.trigger:
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
|
||||
c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
|
||||
c.handleOutbound(vpnIP, f, true)
|
||||
case now := <-clockSource:
|
||||
c.NextOutboundHandshakeTimerTick(now, f)
|
||||
|
@ -149,7 +151,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
|
|||
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
|
||||
if err != nil {
|
||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
|
||||
hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
|
@ -157,7 +159,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
|
|||
} else {
|
||||
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
|
||||
// keep the real packet struct around for logging purposes
|
||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
|
||||
hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
|
@ -245,7 +247,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
|||
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
|
||||
// We have a collision, but this can happen since we can't control
|
||||
// the remote ID. Just log about the situation as a note.
|
||||
hostinfo.logger().
|
||||
hostinfo.logger(c.l).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
|
||||
Info("New host shadows existing host remoteIndex")
|
||||
}
|
||||
|
@ -280,7 +282,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
|||
if found && existingRemoteIndex != nil {
|
||||
// We have a collision, but this can happen since we can't control
|
||||
// the remote ID. Just log about the situation as a note.
|
||||
hostinfo.logger().
|
||||
hostinfo.logger(c.l).
|
||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
|
||||
Info("New host shadows existing host remoteIndex")
|
||||
}
|
||||
|
@ -298,7 +300,7 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
|
|||
defer c.mainHostMap.RUnlock()
|
||||
|
||||
for i := 0; i < 32; i++ {
|
||||
index, err := generateIndex()
|
||||
index, err := generateIndex(c.l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -336,7 +338,7 @@ func (c *HandshakeManager) EmitStats() {
|
|||
|
||||
// Utility functions below
|
||||
|
||||
func generateIndex() (uint32, error) {
|
||||
func generateIndex(l *logrus.Logger) (uint32, error) {
|
||||
b := make([]byte, 4)
|
||||
|
||||
// Let zero mean we don't know the ID, so don't generate zero
|
||||
|
|
|
@ -12,15 +12,15 @@ import (
|
|||
var ips []uint32
|
||||
|
||||
func Test_NewHandshakeManagerIndex(t *testing.T) {
|
||||
|
||||
l := NewTestLogger()
|
||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
|
||||
now := time.Now()
|
||||
blah.NextInboundHandshakeTimerTick(now)
|
||||
|
@ -63,15 +63,16 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
mw := &mockEncWriter{}
|
||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
|
||||
now := time.Now()
|
||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||
|
@ -112,16 +113,17 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
ip := ip2int(net.ParseIP("172.1.1.2"))
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
mw := &mockEncWriter{}
|
||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||
lh := &LightHouse{}
|
||||
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
|
||||
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
|
||||
|
||||
now := time.Now()
|
||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||
|
@ -162,15 +164,16 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
|
|||
}
|
||||
|
||||
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
mw := &mockEncWriter{}
|
||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
|
||||
now := time.Now()
|
||||
blah.NextOutboundHandshakeTimerTick(now, mw)
|
||||
|
@ -216,13 +219,14 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
mainHM := NewHostMap("test", vpncidr, preferredRanges)
|
||||
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
|
||||
|
||||
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
|
||||
|
||||
now := time.Now()
|
||||
blah.NextInboundHandshakeTimerTick(now)
|
||||
|
|
54
hostmap.go
54
hostmap.go
|
@ -33,6 +33,7 @@ type HostMap struct {
|
|||
defaultRoute uint32
|
||||
unsafeRoutes *CIDRTree
|
||||
metricsEnabled bool
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type HostInfo struct {
|
||||
|
@ -83,7 +84,7 @@ type Probe struct {
|
|||
Counter int
|
||||
}
|
||||
|
||||
func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
|
||||
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
|
||||
h := map[uint32]*HostInfo{}
|
||||
i := map[uint32]*HostInfo{}
|
||||
r := map[uint32]*HostInfo{}
|
||||
|
@ -96,6 +97,7 @@ func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *
|
|||
vpnCIDR: vpnCIDR,
|
||||
defaultRoute: 0,
|
||||
unsafeRoutes: NewCIDRTree(),
|
||||
l: l,
|
||||
}
|
||||
return &m
|
||||
}
|
||||
|
@ -160,8 +162,8 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
|
|||
}
|
||||
hm.Unlock()
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
|
||||
Debug("Hostmap vpnIp deleted")
|
||||
}
|
||||
}
|
||||
|
@ -173,8 +175,8 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
|
|||
hm.RemoteIndexes[index] = h
|
||||
hm.Unlock()
|
||||
|
||||
if l.Level > logrus.DebugLevel {
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
|
||||
if hm.l.Level > logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
|
||||
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
||||
Debug("Hostmap remoteIndex added")
|
||||
}
|
||||
|
@ -188,8 +190,8 @@ func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
|
|||
hm.RemoteIndexes[h.remoteIndexId] = h
|
||||
hm.Unlock()
|
||||
|
||||
if l.Level > logrus.DebugLevel {
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
|
||||
if hm.l.Level > logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
|
||||
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
|
||||
Debug("Hostmap vpnIp added")
|
||||
}
|
||||
|
@ -212,8 +214,8 @@ func (hm *HostMap) DeleteIndex(index uint32) {
|
|||
}
|
||||
hm.Unlock()
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
|
||||
Debug("Hostmap index deleted")
|
||||
}
|
||||
}
|
||||
|
@ -236,8 +238,8 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
|
|||
}
|
||||
hm.Unlock()
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
|
||||
Debug("Hostmap remote index deleted")
|
||||
}
|
||||
}
|
||||
|
@ -269,8 +271,8 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
|
|||
}
|
||||
hm.Unlock()
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
|
||||
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||
Debug("Hostmap hostInfo deleted")
|
||||
}
|
||||
|
@ -313,8 +315,10 @@ func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
|
|||
}
|
||||
i.remote = i.Remotes[0].addr
|
||||
hm.Hosts[vpnIp] = i
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
|
||||
Debug("Hostmap remote ip added")
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
|
||||
Debug("Hostmap remote ip added")
|
||||
}
|
||||
}
|
||||
i.ForcePromoteBest(hm.preferredRanges)
|
||||
hm.Unlock()
|
||||
|
@ -377,8 +381,8 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
|
|||
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
||||
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
|
||||
if hm.l.Level >= logrus.DebugLevel {
|
||||
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
|
||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
|
||||
Debug("Hostmap vpnIp added")
|
||||
}
|
||||
|
@ -436,7 +440,7 @@ func (hm *HostMap) Punchy(conn *udpConn) {
|
|||
|
||||
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
|
||||
for _, r := range *routes {
|
||||
l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
|
||||
hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
|
||||
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
|
||||
}
|
||||
}
|
||||
|
@ -566,7 +570,7 @@ func (i *HostInfo) rotateRemote() {
|
|||
i.remote = i.Remotes[0].addr
|
||||
}
|
||||
|
||||
func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
|
||||
func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
|
||||
//TODO: return the error so we can log with more context
|
||||
if len(i.packetStore) < 100 {
|
||||
tempPacket := make([]byte, len(packet))
|
||||
|
@ -574,14 +578,14 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
|
|||
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
|
||||
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
i.logger().
|
||||
i.logger(l).
|
||||
WithField("length", len(i.packetStore)).
|
||||
WithField("stored", true).
|
||||
Debugf("Packet store")
|
||||
}
|
||||
|
||||
} else if l.Level >= logrus.DebugLevel {
|
||||
i.logger().
|
||||
i.logger(l).
|
||||
WithField("length", len(i.packetStore)).
|
||||
WithField("stored", false).
|
||||
Debugf("Packet store")
|
||||
|
@ -589,7 +593,7 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
|
|||
}
|
||||
|
||||
// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
|
||||
func (i *HostInfo) handshakeComplete() {
|
||||
func (i *HostInfo) handshakeComplete(l *logrus.Logger) {
|
||||
//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
|
||||
//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
|
||||
//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
|
||||
|
@ -601,7 +605,7 @@ func (i *HostInfo) handshakeComplete() {
|
|||
atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2)
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
i.logger().Debugf("Sending %d stored packets", len(i.packetStore))
|
||||
i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore))
|
||||
}
|
||||
|
||||
if len(i.packetStore) > 0 {
|
||||
|
@ -689,7 +693,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
|
|||
i.remoteCidr = remoteCidr
|
||||
}
|
||||
|
||||
func (i *HostInfo) logger() *logrus.Entry {
|
||||
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||
if i == nil {
|
||||
return logrus.NewEntry(l)
|
||||
}
|
||||
|
@ -804,7 +808,7 @@ func (d *HostInfoDest) ProbeReceived(probeCount int) {
|
|||
|
||||
// Utility functions
|
||||
|
||||
func localIps(allowList *AllowList) *[]net.IP {
|
||||
func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP {
|
||||
//FIXME: This function is pretty garbage
|
||||
var ips []net.IP
|
||||
ifaces, _ := net.Interfaces()
|
||||
|
|
|
@ -64,12 +64,13 @@ func TestHostInfoDestProbe(t *testing.T) {
|
|||
*/
|
||||
|
||||
func TestHostmap(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
|
||||
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
|
||||
myNets := []*net.IPNet{myNet}
|
||||
preferredRanges := []*net.IPNet{localToMe}
|
||||
|
||||
m := NewHostMap("test", myNet, preferredRanges)
|
||||
m := NewHostMap(l, "test", myNet, preferredRanges)
|
||||
|
||||
a := NewUDPAddrFromString("10.127.0.3:11111")
|
||||
b := NewUDPAddrFromString("1.0.0.1:22222")
|
||||
|
@ -103,10 +104,11 @@ func TestHostmap(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHostmapdebug(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
|
||||
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
|
||||
preferredRanges := []*net.IPNet{localToMe}
|
||||
m := NewHostMap("test", myNet, preferredRanges)
|
||||
m := NewHostMap(l, "test", myNet, preferredRanges)
|
||||
|
||||
a := NewUDPAddrFromString("10.127.0.3:11111")
|
||||
b := NewUDPAddrFromString("1.0.0.1:22222")
|
||||
|
@ -151,11 +153,12 @@ func TestHostMap_rotateRemote(t *testing.T) {
|
|||
}
|
||||
|
||||
func BenchmarkHostmappromote2(b *testing.B) {
|
||||
l := NewTestLogger()
|
||||
for n := 0; n < b.N; n++ {
|
||||
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
|
||||
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
|
||||
preferredRanges := []*net.IPNet{localToMe}
|
||||
m := NewHostMap("test", myNet, preferredRanges)
|
||||
m := NewHostMap(l, "test", myNet, preferredRanges)
|
||||
y := NewUDPAddrFromString("10.128.0.3:11111")
|
||||
a := NewUDPAddrFromString("10.127.0.3:11111")
|
||||
g := NewUDPAddrFromString("1.0.0.1:22222")
|
||||
|
|
40
inside.go
40
inside.go
|
@ -10,7 +10,7 @@ import (
|
|||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -31,8 +31,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
|||
|
||||
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
|
||||
if hostinfo == nil {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
|
||||
WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
|||
// the packet queue.
|
||||
ci.queueLock.Lock()
|
||||
if !ci.ready {
|
||||
hostinfo.cachePacket(message, 0, packet, f.sendMessageNow)
|
||||
hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow)
|
||||
ci.queueLock.Unlock()
|
||||
return
|
||||
}
|
||||
|
@ -59,8 +59,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
|
|||
f.lightHouse.Query(fwPacket.RemoteIP, f)
|
||||
}
|
||||
|
||||
} else if l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger().
|
||||
} else if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).
|
||||
WithField("fwPacket", fwPacket).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping outbound packet")
|
||||
|
@ -104,7 +104,7 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
|||
|
||||
if ci == nil {
|
||||
// if we don't have a connection state, then send a handshake initiation
|
||||
ci = f.newConnectionState(true, noise.HandshakeIX, []byte{}, 0)
|
||||
ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
|
||||
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
||||
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
|
||||
hostinfo.ConnectionState = ci
|
||||
|
@ -135,15 +135,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
|||
fp := &FirewallPacket{}
|
||||
err := newPacket(p, false, fp)
|
||||
if err != nil {
|
||||
l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
||||
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// check if packet is in outbound fw rules
|
||||
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
|
||||
if dropReason != nil {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("fwPacket", fp).
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("fwPacket", fp).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping cached packet")
|
||||
}
|
||||
|
@ -160,8 +160,8 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
|||
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
||||
hostInfo := f.getOrHandshake(vpnIp)
|
||||
if hostInfo == nil {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", IntIp(vpnIp)).
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnIp", IntIp(vpnIp)).
|
||||
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
|
||||
}
|
||||
return
|
||||
|
@ -172,7 +172,7 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
|
|||
// the packet queue.
|
||||
hostInfo.ConnectionState.queueLock.Lock()
|
||||
if !hostInfo.ConnectionState.ready {
|
||||
hostInfo.cachePacket(t, st, p, f.sendMessageToVpnIp)
|
||||
hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToVpnIp)
|
||||
hostInfo.ConnectionState.queueLock.Unlock()
|
||||
return
|
||||
}
|
||||
|
@ -191,8 +191,8 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
|
|||
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
|
||||
hostInfo := f.getOrHandshake(vpnIp)
|
||||
if hostInfo == nil {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", IntIp(vpnIp)).
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnIp", IntIp(vpnIp)).
|
||||
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
|
||||
}
|
||||
return
|
||||
|
@ -203,7 +203,7 @@ func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp
|
|||
// the packet queue.
|
||||
hostInfo.ConnectionState.queueLock.Lock()
|
||||
if !hostInfo.ConnectionState.ready {
|
||||
hostInfo.cachePacket(t, st, p, f.sendMessageToAll)
|
||||
hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToAll)
|
||||
hostInfo.ConnectionState.queueLock.Unlock()
|
||||
return
|
||||
}
|
||||
|
@ -247,8 +247,8 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
|
|||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||
f.lightHouse.Query(hostinfo.hostId, f)
|
||||
hostinfo.lastRebindCount = f.rebindCount
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -256,7 +256,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
|
|||
//TODO: see above note on lock
|
||||
//ci.writeLock.Unlock()
|
||||
if err != nil {
|
||||
hostinfo.logger().WithError(err).
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).WithField("counter", c).
|
||||
WithField("attemptedCounter", c).
|
||||
Error("Failed to encrypt outgoing packet")
|
||||
|
@ -265,7 +265,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
|
|||
|
||||
err = f.writers[q].WriteTo(out, remote)
|
||||
if err != nil {
|
||||
hostinfo.logger().WithError(err).
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
}
|
||||
return c
|
||||
|
|
40
interface.go
40
interface.go
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const mtu = 9001
|
||||
|
@ -42,6 +43,7 @@ type InterfaceConfig struct {
|
|||
version string
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
|
@ -73,6 +75,7 @@ type Interface struct {
|
|||
|
||||
metricHandshakes metrics.Histogram
|
||||
messageMetrics *MessageMetrics
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
||||
|
@ -113,9 +116,10 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
|
|||
|
||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||
messageMetrics: c.MessageMetrics,
|
||||
l: c.l,
|
||||
}
|
||||
|
||||
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
|
||||
ifce.connectionManager = newConnectionManager(c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
|
||||
|
||||
return ifce, nil
|
||||
}
|
||||
|
@ -125,10 +129,10 @@ func (f *Interface) run() {
|
|||
|
||||
addr, err := f.outside.LocalAddr()
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to get udp listen address")
|
||||
f.l.WithError(err).Error("Failed to get udp listen address")
|
||||
}
|
||||
|
||||
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
|
||||
f.l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
|
||||
WithField("build", f.version).WithField("udpAddr", addr).
|
||||
Info("Nebula interface is active")
|
||||
|
||||
|
@ -140,14 +144,14 @@ func (f *Interface) run() {
|
|||
if i > 0 {
|
||||
reader, err = f.inside.NewMultiQueueReader()
|
||||
if err != nil {
|
||||
l.Fatal(err)
|
||||
f.l.Fatal(err)
|
||||
}
|
||||
}
|
||||
f.readers[i] = reader
|
||||
}
|
||||
|
||||
if err := f.inside.Activate(); err != nil {
|
||||
l.Fatal(err)
|
||||
f.l.Fatal(err)
|
||||
}
|
||||
|
||||
// Launch n queues to read packets from udp
|
||||
|
@ -187,12 +191,12 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
|||
for {
|
||||
n, err := reader.Read(packet)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Error while reading outbound packet")
|
||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
|
||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -208,21 +212,21 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
|
|||
func (f *Interface) reloadCA(c *Config) {
|
||||
// reload and check regardless
|
||||
// todo: need mutex?
|
||||
newCAs, err := loadCAFromConfig(c)
|
||||
newCAs, err := loadCAFromConfig(f.l, c)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Could not refresh trusted CA certificates")
|
||||
f.l.WithError(err).Error("Could not refresh trusted CA certificates")
|
||||
return
|
||||
}
|
||||
|
||||
trustedCAs = newCAs
|
||||
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
|
||||
f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
|
||||
}
|
||||
|
||||
func (f *Interface) reloadCertKey(c *Config) {
|
||||
// reload and check in all cases
|
||||
cs, err := NewCertStateFromConfig(c)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Could not refresh client cert")
|
||||
f.l.WithError(err).Error("Could not refresh client cert")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -230,24 +234,24 @@ func (f *Interface) reloadCertKey(c *Config) {
|
|||
oldIPs := f.certState.certificate.Details.Ips
|
||||
newIPs := cs.certificate.Details.Ips
|
||||
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
|
||||
l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
|
||||
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
|
||||
return
|
||||
}
|
||||
|
||||
f.certState = cs
|
||||
l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
|
||||
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
|
||||
}
|
||||
|
||||
func (f *Interface) reloadFirewall(c *Config) {
|
||||
//TODO: need to trigger/detect if the certificate changed too
|
||||
if c.HasChanged("firewall") == false {
|
||||
l.Debug("No firewall config change detected")
|
||||
f.l.Debug("No firewall config change detected")
|
||||
return
|
||||
}
|
||||
|
||||
fw, err := NewFirewallFromConfig(f.certState.certificate, c)
|
||||
fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Error while creating firewall during reload")
|
||||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -260,7 +264,7 @@ func (f *Interface) reloadFirewall(c *Config) {
|
|||
// If rulesVersion is back to zero, we have wrapped all the way around. Be
|
||||
// safe and just reset conntrack in this case.
|
||||
if fw.rulesVersion == 0 {
|
||||
l.WithField("firewallHash", fw.GetRuleHash()).
|
||||
f.l.WithField("firewallHash", fw.GetRuleHash()).
|
||||
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
||||
WithField("rulesVersion", fw.rulesVersion).
|
||||
Warn("firewall rulesVersion has overflowed, resetting conntrack")
|
||||
|
@ -271,7 +275,7 @@ func (f *Interface) reloadFirewall(c *Config) {
|
|||
f.firewall = fw
|
||||
|
||||
oldFw.Destroy()
|
||||
l.WithField("firewallHash", fw.GetRuleHash()).
|
||||
f.l.WithField("firewallHash", fw.GetRuleHash()).
|
||||
WithField("oldFirewallHash", oldFw.GetRuleHash()).
|
||||
WithField("rulesVersion", fw.rulesVersion).
|
||||
Info("New firewall has been installed")
|
||||
|
|
|
@ -48,6 +48,7 @@ type LightHouse struct {
|
|||
|
||||
metrics *MessageMetrics
|
||||
metricHolepunchTx metrics.Counter
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type EncWriter interface {
|
||||
|
@ -55,7 +56,7 @@ type EncWriter interface {
|
|||
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
|
||||
}
|
||||
|
||||
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
|
||||
func NewLightHouse(l *logrus.Logger, amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
|
||||
h := LightHouse{
|
||||
amLighthouse: amLighthouse,
|
||||
myIp: myIp,
|
||||
|
@ -67,6 +68,7 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
|
|||
punchConn: pc,
|
||||
punchBack: punchBack,
|
||||
punchDelay: punchDelay,
|
||||
l: l,
|
||||
}
|
||||
|
||||
if metricsEnabled {
|
||||
|
@ -126,7 +128,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
|
|||
// Send a query to the lighthouses and hope for the best next time
|
||||
query, err := proto.Marshal(NewLhQueryByInt(ip))
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
|
||||
lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -159,7 +161,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
|
|||
lh.Lock()
|
||||
//l.Debugln(lh.addrMap)
|
||||
delete(lh.addrMap, vpnIP)
|
||||
l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
|
||||
lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
|
||||
lh.Unlock()
|
||||
}
|
||||
|
||||
|
@ -181,7 +183,7 @@ func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) {
|
|||
}
|
||||
|
||||
allow := lh.remoteAllowList.Allow(toIp.IP)
|
||||
l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
|
||||
lh.l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
|
||||
if !allow {
|
||||
return
|
||||
}
|
||||
|
@ -270,7 +272,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
|
|||
var v4 []*IpAndPort
|
||||
var v6 []*Ip6AndPort
|
||||
|
||||
for _, e := range *localIps(lh.localAllowList) {
|
||||
for _, e := range *localIps(lh.l, lh.localAllowList) {
|
||||
// Only add IPs that aren't my VPN/tun IP
|
||||
if ip2int(e) != lh.myIp {
|
||||
ipp := NewIpAndPort(e, lh.nebulaPort)
|
||||
|
@ -297,7 +299,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
|
|||
for vpnIp := range lh.lighthouses {
|
||||
mm, err := proto.Marshal(m)
|
||||
if err != nil {
|
||||
l.Debugf("Invalid marshal to update")
|
||||
lh.l.Debugf("Invalid marshal to update")
|
||||
}
|
||||
//l.Error("LIGHTHOUSE PACKET SEND", mm)
|
||||
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
|
||||
|
@ -368,14 +370,14 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
|||
n := lhh.resetMeta()
|
||||
err := proto.UnmarshalMerge(p, n)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
|
||||
lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
|
||||
Error("Failed to unmarshal lighthouse packet")
|
||||
//TODO: send recv_error?
|
||||
return
|
||||
}
|
||||
|
||||
if n.Details == nil {
|
||||
l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
|
||||
lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
|
||||
Error("Invalid lighthouse update")
|
||||
//TODO: send recv_error?
|
||||
return
|
||||
|
@ -387,7 +389,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
|||
case NebulaMeta_HostQuery:
|
||||
// Exit if we don't answer queries
|
||||
if !lh.amLighthouse {
|
||||
l.Debugln("I don't answer queries, but received from: ", rAddr)
|
||||
lh.l.Debugln("I don't answer queries, but received from: ", rAddr)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -422,7 +424,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
|||
|
||||
reply, err := proto.Marshal(n)
|
||||
if err != nil {
|
||||
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
|
||||
lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
|
||||
return
|
||||
}
|
||||
lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
||||
|
@ -431,7 +433,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
|||
// This signals the other side to punch some zero byte udp packets
|
||||
ips, err = lh.Query(vpnIp, f)
|
||||
if err != nil {
|
||||
l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
|
||||
lh.l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
|
||||
return
|
||||
} else {
|
||||
//l.Debugln("Notify host to punch", iap)
|
||||
|
@ -492,7 +494,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
|||
case NebulaMeta_HostUpdateNotification:
|
||||
//Simple check that the host sent this not someone else
|
||||
if n.Details.VpnIp != vpnIp {
|
||||
l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
|
||||
lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -530,9 +532,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
|||
|
||||
}()
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
if lh.l.Level >= logrus.DebugLevel {
|
||||
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
|
||||
l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
|
||||
lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -549,9 +551,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
|||
|
||||
}()
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
if lh.l.Level >= logrus.DebugLevel {
|
||||
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
|
||||
l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
|
||||
lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -561,7 +563,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
|
|||
if lh.punchBack {
|
||||
go func() {
|
||||
time.Sleep(time.Second * 5)
|
||||
l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
|
||||
lh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
|
||||
// TODO we have to allocate a new output buffer here since we are spawning a new goroutine
|
||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
||||
// managed by a channel.
|
||||
|
|
|
@ -65,12 +65,13 @@ func TestSetipandportsfromudpaddrs(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_lhStaticMapping(t *testing.T) {
|
||||
l := NewTestLogger()
|
||||
lh1 := "10.128.0.2"
|
||||
lh1IP := net.ParseIP(lh1)
|
||||
|
||||
udpServer, _ := NewListener("0.0.0.0", 0, true)
|
||||
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
|
||||
|
||||
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
meh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
|
||||
err := meh.ValidateLHStaticEntries()
|
||||
assert.Nil(t, err)
|
||||
|
@ -78,19 +79,20 @@ func Test_lhStaticMapping(t *testing.T) {
|
|||
lh2 := "10.128.0.3"
|
||||
lh2IP := net.ParseIP(lh2)
|
||||
|
||||
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
meh = NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
|
||||
err = meh.ValidateLHStaticEntries()
|
||||
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||
}
|
||||
|
||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
l := NewTestLogger()
|
||||
lh1 := "10.128.0.2"
|
||||
lh1IP := net.ParseIP(lh1)
|
||||
|
||||
udpServer, _ := NewListener("0.0.0.0", 0, true)
|
||||
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
|
||||
|
||||
lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
|
||||
hAddr := NewUDPAddrFromString("4.5.6.7:12345")
|
||||
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
|
||||
|
@ -136,7 +138,8 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||
}
|
||||
|
||||
func Test_lhRemoteAllowList(t *testing.T) {
|
||||
c := NewConfig()
|
||||
l := NewTestLogger()
|
||||
c := NewConfig(l)
|
||||
c.Settings["remoteallowlist"] = map[interface{}]interface{}{
|
||||
"10.20.0.0/12": false,
|
||||
}
|
||||
|
@ -146,9 +149,9 @@ func Test_lhRemoteAllowList(t *testing.T) {
|
|||
lh1 := "10.128.0.2"
|
||||
lh1IP := net.ParseIP(lh1)
|
||||
|
||||
udpServer, _ := NewListener("0.0.0.0", 0, true)
|
||||
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
|
||||
|
||||
lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
|
||||
lh.SetRemoteAllowList(allowList)
|
||||
|
||||
remote1 := "10.20.0.3"
|
||||
|
|
29
main.go
29
main.go
|
@ -11,13 +11,10 @@ import (
|
|||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// The caller should provide a real logger, we have one just in case
|
||||
var l = logrus.New()
|
||||
|
||||
type m map[string]interface{}
|
||||
|
||||
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
|
||||
l = logger
|
||||
l := logger
|
||||
l.Formatter = &logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
}
|
||||
|
@ -46,7 +43,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
})
|
||||
|
||||
// trustedCAs is currently a global, so loadCA operates on that global directly
|
||||
trustedCAs, err = loadCAFromConfig(config)
|
||||
trustedCAs, err = loadCAFromConfig(l, config)
|
||||
if err != nil {
|
||||
//The errors coming out of loadCA are already nicely formatted
|
||||
return nil, NewContextualError("Failed to load ca from config", nil, err)
|
||||
|
@ -60,7 +57,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
}
|
||||
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
|
||||
|
||||
fw, err := NewFirewallFromConfig(cs.certificate, config)
|
||||
fw, err := NewFirewallFromConfig(l, cs.certificate, config)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Error while loading firewall rules", nil, err)
|
||||
}
|
||||
|
@ -78,9 +75,9 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
}
|
||||
|
||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||
wireSSHReload(ssh, config)
|
||||
wireSSHReload(l, ssh, config)
|
||||
if config.GetBool("sshd.enabled", false) {
|
||||
err = configSSH(ssh, config)
|
||||
err = configSSH(l, ssh, config)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Error while configuring the sshd", nil, err)
|
||||
}
|
||||
|
@ -136,6 +133,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
|
||||
case tunFd != nil:
|
||||
tun, err = newTunFromFd(
|
||||
l,
|
||||
*tunFd,
|
||||
tunCidr,
|
||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
|
@ -145,6 +143,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
)
|
||||
default:
|
||||
tun, err = newTun(
|
||||
l,
|
||||
config.GetString("tun.dev", ""),
|
||||
tunCidr,
|
||||
config.GetInt("tun.mtu", DEFAULT_MTU),
|
||||
|
@ -166,7 +165,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
|
||||
if !configTest {
|
||||
for i := 0; i < routines; i++ {
|
||||
udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
|
||||
udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||
}
|
||||
|
@ -222,7 +221,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
}
|
||||
}
|
||||
|
||||
hostMap := NewHostMap("main", tunCidr, preferredRanges)
|
||||
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
|
||||
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
|
||||
hostMap.addUnsafeRoutes(&unsafeRoutes)
|
||||
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
|
||||
|
@ -266,6 +265,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
}
|
||||
|
||||
lightHouse := NewLightHouse(
|
||||
l,
|
||||
amLighthouse,
|
||||
ip2int(tunCidr.IP),
|
||||
lighthouseHosts,
|
||||
|
@ -337,7 +337,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
messageMetrics: messageMetrics,
|
||||
}
|
||||
|
||||
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
||||
handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
||||
lightHouse.handshakeTrigger = handshakeManager.trigger
|
||||
|
||||
//TODO: These will be reused for psk
|
||||
|
@ -367,6 +367,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
version: buildVersion,
|
||||
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
l: l,
|
||||
}
|
||||
|
||||
switch ifConfig.Cipher {
|
||||
|
@ -395,7 +396,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
go lightHouse.LhUpdateWorker(ifce)
|
||||
}
|
||||
|
||||
err = startStats(config, configTest)
|
||||
err = startStats(l, config, configTest)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to start stats emitter", nil, err)
|
||||
}
|
||||
|
@ -407,12 +408,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
|
|||
//TODO: check if we _should_ be emitting stats
|
||||
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
|
||||
|
||||
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
||||
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
|
||||
|
||||
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
||||
if amLighthouse && serveDns {
|
||||
l.Debugln("Starting dns server")
|
||||
go dnsMain(hostMap, config)
|
||||
go dnsMain(l, hostMap, config)
|
||||
}
|
||||
|
||||
return &Control{ifce, l}, nil
|
||||
|
|
29
main_test.go
29
main_test.go
|
@ -1 +1,30 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewTestLogger() *logrus.Logger {
|
||||
l := logrus.New()
|
||||
|
||||
v := os.Getenv("TEST_LOGS")
|
||||
if v == "" {
|
||||
l.SetOutput(ioutil.Discard)
|
||||
return l
|
||||
}
|
||||
|
||||
switch v {
|
||||
case "1":
|
||||
// This is the default level but we are being explicit
|
||||
l.SetLevel(logrus.InfoLevel)
|
||||
case "2":
|
||||
l.SetLevel(logrus.DebugLevel)
|
||||
case "3":
|
||||
l.SetLevel(logrus.TraceLevel)
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
|
50
outside.go
50
outside.go
|
@ -24,7 +24,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
|||
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
|
||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||
if len(packet) > 1 {
|
||||
l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
|
||||
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
|||
|
||||
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt lighthouse packet")
|
||||
|
||||
|
@ -78,7 +78,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
|||
|
||||
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
|
||||
WithField("packet", packet).
|
||||
Error("Failed to decrypt test packet")
|
||||
|
||||
|
@ -115,7 +115,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
|||
return
|
||||
}
|
||||
|
||||
hostinfo.logger().WithField("udpAddr", addr).
|
||||
hostinfo.logger(f.l).WithField("udpAddr", addr).
|
||||
Info("Close tunnel received, tearing down.")
|
||||
|
||||
f.closeTunnel(hostinfo)
|
||||
|
@ -123,7 +123,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
|
|||
|
||||
default:
|
||||
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
|
||||
hostinfo.logger().Debugf("Unexpected packet received from %s", addr)
|
||||
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -143,18 +143,18 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
|||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
|
||||
if hostDidRoam(hostinfo.remote, addr) {
|
||||
if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
|
||||
hostinfo.logger().WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
|
||||
hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
|
||||
return
|
||||
}
|
||||
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
|
||||
Info("Host roamed to new udp ip/port.")
|
||||
hostinfo.lastRoam = time.Now()
|
||||
remoteCopy := *hostinfo.remote
|
||||
|
@ -170,7 +170,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
|
|||
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
|
||||
// If connectionstate exists and the replay protector allows, process packet
|
||||
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
|
||||
if ci == nil || !ci.window.Check(header.MessageCounter) {
|
||||
if ci == nil || !ci.window.Check(f.l, header.MessageCounter) {
|
||||
f.sendRecvError(addr, header.RemoteIndex)
|
||||
return false
|
||||
}
|
||||
|
@ -247,8 +247,8 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if !hostinfo.ConnectionState.window.Update(mc) {
|
||||
hostinfo.logger().WithField("header", header).
|
||||
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
|
||||
hostinfo.logger(f.l).WithField("header", header).
|
||||
Debugln("dropping out of window packet")
|
||||
return nil, errors.New("out of window packet")
|
||||
}
|
||||
|
@ -261,7 +261,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||
|
||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger().WithError(err).Error("Failed to decrypt packet")
|
||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||
//TODO: maybe after build 64 is out? 06/14/2018 - NB
|
||||
//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
|
||||
return
|
||||
|
@ -269,21 +269,21 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||
|
||||
err = newPacket(out, true, fwPacket)
|
||||
if err != nil {
|
||||
hostinfo.logger().WithError(err).WithField("packet", out).
|
||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
||||
Warnf("Error while validating inbound packet")
|
||||
return
|
||||
}
|
||||
|
||||
if !hostinfo.ConnectionState.window.Update(messageCounter) {
|
||||
hostinfo.logger().WithField("fwPacket", fwPacket).
|
||||
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
|
||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping out of window packet")
|
||||
return
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
|
||||
if dropReason != nil {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger().WithField("fwPacket", fwPacket).
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping inbound packet")
|
||||
}
|
||||
|
@ -293,7 +293,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||
f.connectionManager.In(hostinfo.hostId)
|
||||
_, err = f.readers[q].Write(out)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to write to tun")
|
||||
f.l.WithError(err).Error("Failed to write to tun")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -303,16 +303,16 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
|
|||
//TODO: this should be a signed message so we can trust that we should drop the index
|
||||
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
|
||||
f.outside.WriteTo(b, endpoint)
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("index", index).
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("index", index).
|
||||
WithField("udpAddr", endpoint).
|
||||
Debug("Recv error sent")
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
l.WithField("index", h.RemoteIndex).
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("index", h.RemoteIndex).
|
||||
WithField("udpAddr", addr).
|
||||
Debug("Recv error received")
|
||||
}
|
||||
|
@ -322,7 +322,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
|||
|
||||
hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
|
||||
if err != nil {
|
||||
l.Debugln(err, ": ", h.RemoteIndex)
|
||||
f.l.Debugln(err, ": ", h.RemoteIndex)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -333,7 +333,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
|
|||
return
|
||||
}
|
||||
if hostinfo.remote != nil && hostinfo.remote.String() != addr.String() {
|
||||
l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
||||
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,8 @@ import (
|
|||
)
|
||||
|
||||
func TestNewPunchyFromConfig(t *testing.T) {
|
||||
c := NewConfig()
|
||||
l := NewTestLogger()
|
||||
c := NewConfig(l)
|
||||
|
||||
// Test defaults
|
||||
p := NewPunchyFromConfig(c)
|
||||
|
|
20
ssh.go
20
ssh.go
|
@ -44,10 +44,10 @@ type sshCreateTunnelFlags struct {
|
|||
Address string
|
||||
}
|
||||
|
||||
func wireSSHReload(ssh *sshd.SSHServer, c *Config) {
|
||||
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
|
||||
c.RegisterReloadCallback(func(c *Config) {
|
||||
if c.GetBool("sshd.enabled", false) {
|
||||
err := configSSH(ssh, c)
|
||||
err := configSSH(l, ssh, c)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to reconfigure the sshd")
|
||||
ssh.Stop()
|
||||
|
@ -58,7 +58,7 @@ func wireSSHReload(ssh *sshd.SSHServer, c *Config) {
|
|||
})
|
||||
}
|
||||
|
||||
func configSSH(ssh *sshd.SSHServer, c *Config) error {
|
||||
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error {
|
||||
//TODO conntrack list
|
||||
//TODO print firewall rules or hash?
|
||||
|
||||
|
@ -149,7 +149,7 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
|
||||
func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "list-hostmap",
|
||||
ShortDescription: "List all known previously connected hosts",
|
||||
|
@ -225,13 +225,17 @@ func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostM
|
|||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "log-level",
|
||||
ShortDescription: "Gets or sets the current log level",
|
||||
Callback: sshLogLevel,
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
return sshLogLevel(l, fs, a, w)
|
||||
},
|
||||
})
|
||||
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
Name: "log-format",
|
||||
ShortDescription: "Gets or sets the current log format",
|
||||
Callback: sshLogFormat,
|
||||
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
return sshLogFormat(l, fs, a, w)
|
||||
},
|
||||
})
|
||||
|
||||
ssh.RegisterCommand(&sshd.Command{
|
||||
|
@ -629,7 +633,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
if len(a) == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
||||
}
|
||||
|
@ -643,7 +647,7 @@ func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
|
|||
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
|
||||
}
|
||||
|
||||
func sshLogFormat(fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
|
||||
if len(a) == 0 {
|
||||
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
|
||||
}
|
||||
|
|
11
stats.go
11
stats.go
|
@ -13,9 +13,10 @@ import (
|
|||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func startStats(c *Config, configTest bool) error {
|
||||
func startStats(l *logrus.Logger, c *Config, configTest bool) error {
|
||||
mType := c.GetString("stats.type", "")
|
||||
if mType == "" || mType == "none" {
|
||||
return nil
|
||||
|
@ -28,9 +29,9 @@ func startStats(c *Config, configTest bool) error {
|
|||
|
||||
switch mType {
|
||||
case "graphite":
|
||||
startGraphiteStats(interval, c, configTest)
|
||||
startGraphiteStats(l, interval, c, configTest)
|
||||
case "prometheus":
|
||||
startPrometheusStats(interval, c, configTest)
|
||||
startPrometheusStats(l, interval, c, configTest)
|
||||
default:
|
||||
return fmt.Errorf("stats.type was not understood: %s", mType)
|
||||
}
|
||||
|
@ -44,7 +45,7 @@ func startStats(c *Config, configTest bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
|
||||
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
|
||||
proto := c.GetString("stats.protocol", "tcp")
|
||||
host := c.GetString("stats.host", "")
|
||||
if host == "" {
|
||||
|
@ -64,7 +65,7 @@ func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func startPrometheusStats(i time.Duration, c *Config, configTest bool) error {
|
||||
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
|
||||
namespace := c.GetString("stats.namespace", "")
|
||||
subsystem := c.GetString("stats.subsystem", "")
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
|
@ -19,9 +20,10 @@ type Tun struct {
|
|||
TXQueueLen int
|
||||
Routes []route
|
||||
UnsafeRoutes []route
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
ifce = &Tun{
|
||||
|
@ -33,6 +35,7 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
|
|||
TXQueueLen: txQueueLen,
|
||||
Routes: routes,
|
||||
UnsafeRoutes: unsafeRoutes,
|
||||
l: l,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"os/exec"
|
||||
"strconv"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
|
@ -17,11 +18,11 @@ type Tun struct {
|
|||
Cidr *net.IPNet
|
||||
MTU int
|
||||
UnsafeRoutes []route
|
||||
|
||||
l *logrus.Logger
|
||||
*water.Interface
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
if len(routes) > 0 {
|
||||
return nil, fmt.Errorf("route MTU not supported in Darwin")
|
||||
}
|
||||
|
@ -31,10 +32,11 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
|||
Cidr: cidr,
|
||||
MTU: defaultMTU,
|
||||
UnsafeRoutes: unsafeRoutes,
|
||||
l: l,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
|
|
|
@ -9,24 +9,23 @@ import (
|
|||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type disabledTun struct {
|
||||
read chan []byte
|
||||
cidr *net.IPNet
|
||||
logger *log.Logger
|
||||
read chan []byte
|
||||
cidr *net.IPNet
|
||||
|
||||
// Track these metrics since we don't have the tun device to do it for us
|
||||
tx metrics.Counter
|
||||
rx metrics.Counter
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *log.Logger) *disabledTun {
|
||||
func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
||||
tun := &disabledTun{
|
||||
cidr: cidr,
|
||||
read: make(chan []byte, queueLen),
|
||||
logger: l,
|
||||
cidr: cidr,
|
||||
read: make(chan []byte, queueLen),
|
||||
l: l,
|
||||
}
|
||||
|
||||
if metricsEnabled {
|
||||
|
@ -63,8 +62,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
|
|||
}
|
||||
|
||||
t.tx.Inc(1)
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
t.logger.WithField("raw", prettyPacket(r)).Debugf("Write payload")
|
||||
if t.l.Level >= logrus.DebugLevel {
|
||||
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
|
||||
}
|
||||
|
||||
return copy(b, r), nil
|
||||
|
@ -103,7 +102,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
|
|||
select {
|
||||
case t.read <- buf:
|
||||
default:
|
||||
t.logger.Debugf("tun_disabled: dropped ICMP Echo Reply response")
|
||||
t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
|
||||
}
|
||||
|
||||
return true
|
||||
|
@ -114,11 +113,11 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
|||
|
||||
// Check for ICMP Echo Request before spending time doing the full parsing
|
||||
if t.handleICMPEchoRequest(b) {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
|
||||
if t.l.Level >= logrus.DebugLevel {
|
||||
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
|
||||
}
|
||||
} else if l.Level >= logrus.DebugLevel {
|
||||
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
|
||||
} else if t.l.Level >= logrus.DebugLevel {
|
||||
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
@ -18,15 +20,16 @@ type Tun struct {
|
|||
Cidr *net.IPNet
|
||||
MTU int
|
||||
UnsafeRoutes []route
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
if len(routes) > 0 {
|
||||
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
|
||||
}
|
||||
|
@ -41,6 +44,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
|||
Cidr: cidr,
|
||||
MTU: defaultMTU,
|
||||
UnsafeRoutes: unsafeRoutes,
|
||||
l: l,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -52,21 +56,21 @@ func (c *Tun) Activate() error {
|
|||
}
|
||||
|
||||
// TODO use syscalls instead of exec.Command
|
||||
l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
|
||||
c.l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
|
||||
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
|
||||
c.l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
|
||||
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
}
|
||||
l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
|
||||
c.l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
|
||||
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
// Unsafe path routes
|
||||
for _, r := range c.UnsafeRoutes {
|
||||
l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
|
||||
c.l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
|
||||
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
|
||||
}
|
||||
|
|
12
tun_linux.go
12
tun_linux.go
|
@ -10,6 +10,7 @@ import (
|
|||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
@ -24,6 +25,7 @@ type Tun struct {
|
|||
TXQueueLen int
|
||||
Routes []route
|
||||
UnsafeRoutes []route
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type ifReq struct {
|
||||
|
@ -78,7 +80,7 @@ type ifreqQLEN struct {
|
|||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
|
@ -91,11 +93,12 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
|
|||
TXQueueLen: txQueueLen,
|
||||
Routes: routes,
|
||||
UnsafeRoutes: unsafeRoutes,
|
||||
l: l,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -131,6 +134,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
|||
TXQueueLen: txQueueLen,
|
||||
Routes: routes,
|
||||
UnsafeRoutes: unsafeRoutes,
|
||||
l: l,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -233,14 +237,14 @@ func (c Tun) Activate() error {
|
|||
ifm := ifreqMTU{Name: devName, MTU: int32(c.MaxMTU)}
|
||||
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
||||
l.WithError(err).Error("Failed to set tun mtu")
|
||||
c.l.WithError(err).Error("Failed to set tun mtu")
|
||||
}
|
||||
|
||||
// Set the transmit queue length
|
||||
ifrq := ifreqQLEN{Name: devName, Value: int32(c.TXQueueLen)}
|
||||
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
||||
l.WithError(err).Error("Failed to set tun tx queue length")
|
||||
c.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
|
|
|
@ -9,7 +9,8 @@ import (
|
|||
)
|
||||
|
||||
func Test_parseRoutes(t *testing.T) {
|
||||
c := NewConfig()
|
||||
l := NewTestLogger()
|
||||
c := NewConfig(l)
|
||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||
|
||||
// test no routes config
|
||||
|
@ -104,7 +105,8 @@ func Test_parseRoutes(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_parseUnsafeRoutes(t *testing.T) {
|
||||
c := NewConfig()
|
||||
l := NewTestLogger()
|
||||
c := NewConfig(l)
|
||||
_, n, _ := net.ParseCIDR("10.0.0.0/24")
|
||||
|
||||
// test no routes config
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"os/exec"
|
||||
"strconv"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
|
@ -15,15 +16,16 @@ type Tun struct {
|
|||
Cidr *net.IPNet
|
||||
MTU int
|
||||
UnsafeRoutes []route
|
||||
l *logrus.Logger
|
||||
|
||||
*water.Interface
|
||||
}
|
||||
|
||||
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
||||
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
|
||||
if len(routes) > 0 {
|
||||
return nil, fmt.Errorf("route MTU not supported in Windows")
|
||||
}
|
||||
|
@ -33,6 +35,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
|
|||
Cidr: cidr,
|
||||
MTU: defaultMTU,
|
||||
UnsafeRoutes: unsafeRoutes,
|
||||
l: l,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
// +build !e2e_testing
|
||||
|
||||
package nebula
|
||||
|
||||
import (
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
// +build !e2e_testing
|
||||
|
||||
package nebula
|
||||
|
||||
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
// +build !e2e_testing
|
||||
|
||||
package nebula
|
||||
|
||||
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// +build !linux android
|
||||
// +build !e2e_testing
|
||||
|
||||
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
|
||||
// means it can be used on platforms like Darwin and Windows.
|
||||
|
@ -9,20 +10,23 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type udpConn struct {
|
||||
*net.UDPConn
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
||||
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
|
||||
lc := NewListenConfig(multi)
|
||||
pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if uc, ok := pc.(*net.UDPConn); ok {
|
||||
return &udpConn{UDPConn: uc}, nil
|
||||
return &udpConn{UDPConn: uc, l: l}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
|
||||
}
|
||||
|
@ -76,13 +80,13 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
|
|||
// Just read one packet at a time
|
||||
n, rua, err := u.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to read packets")
|
||||
f.l.WithError(err).Error("Failed to read packets")
|
||||
continue
|
||||
}
|
||||
|
||||
udpAddr.IP = rua.IP
|
||||
udpAddr.Port = uint16(rua.Port)
|
||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get())
|
||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
23
udp_linux.go
23
udp_linux.go
|
@ -1,4 +1,5 @@
|
|||
// +build !android
|
||||
// +build !e2e_testing
|
||||
|
||||
package nebula
|
||||
|
||||
|
@ -10,6 +11,7 @@ import (
|
|||
"unsafe"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
|
@ -17,6 +19,7 @@ import (
|
|||
|
||||
type udpConn struct {
|
||||
sysFd int
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
var x int
|
||||
|
@ -38,7 +41,7 @@ const (
|
|||
|
||||
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
|
||||
|
||||
func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
||||
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
|
||||
syscall.ForkLock.RLock()
|
||||
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
||||
if err == nil {
|
||||
|
@ -70,7 +73,7 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
|
|||
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
|
||||
//l.Println(v, err)
|
||||
|
||||
return &udpConn{sysFd: fd}, err
|
||||
return &udpConn{sysFd: fd, l: l}, err
|
||||
}
|
||||
|
||||
func (u *udpConn) Rebind() error {
|
||||
|
@ -153,7 +156,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
|
|||
for {
|
||||
n, err := read(msgs)
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Failed to read packets")
|
||||
u.l.WithError(err).Error("Failed to read packets")
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -161,7 +164,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
|
|||
for i := 0; i < n; i++ {
|
||||
udpAddr.IP = names[i][8:24]
|
||||
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
|
||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get())
|
||||
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -244,12 +247,12 @@ func (u *udpConn) reloadConfig(c *Config) {
|
|||
if err == nil {
|
||||
s, err := u.GetRecvBuffer()
|
||||
if err == nil {
|
||||
l.WithField("size", s).Info("listen.read_buffer was set")
|
||||
u.l.WithField("size", s).Info("listen.read_buffer was set")
|
||||
} else {
|
||||
l.WithError(err).Warn("Failed to get listen.read_buffer")
|
||||
u.l.WithError(err).Warn("Failed to get listen.read_buffer")
|
||||
}
|
||||
} else {
|
||||
l.WithError(err).Error("Failed to set listen.read_buffer")
|
||||
u.l.WithError(err).Error("Failed to set listen.read_buffer")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -259,12 +262,12 @@ func (u *udpConn) reloadConfig(c *Config) {
|
|||
if err == nil {
|
||||
s, err := u.GetSendBuffer()
|
||||
if err == nil {
|
||||
l.WithField("size", s).Info("listen.write_buffer was set")
|
||||
u.l.WithField("size", s).Info("listen.write_buffer was set")
|
||||
} else {
|
||||
l.WithError(err).Warn("Failed to get listen.write_buffer")
|
||||
u.l.WithError(err).Warn("Failed to get listen.write_buffer")
|
||||
}
|
||||
} else {
|
||||
l.WithError(err).Error("Failed to set listen.write_buffer")
|
||||
u.l.WithError(err).Error("Failed to set listen.write_buffer")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
// +build linux
|
||||
// +build 386 amd64p32 arm mips mipsle
|
||||
// +build !android
|
||||
// +build !e2e_testing
|
||||
|
||||
package nebula
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
// +build linux
|
||||
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
|
||||
// +build !android
|
||||
// +build !e2e_testing
|
||||
|
||||
package nebula
|
||||
|
||||
|
|
Loading…
Reference in New Issue