diff --git a/connection_manager.go b/connection_manager.go index a80a6c3..82167ea 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -19,12 +19,12 @@ type connectionManager struct { inLock *sync.RWMutex out map[iputil.VpnIp]struct{} outLock *sync.RWMutex - TrafficTimer *SystemTimerWheel + TrafficTimer *LockingTimerWheel[iputil.VpnIp] intf *Interface pendingDeletion map[iputil.VpnIp]int pendingDeletionLock *sync.RWMutex - pendingDeletionTimer *SystemTimerWheel + pendingDeletionTimer *LockingTimerWheel[iputil.VpnIp] checkInterval int pendingDeletionInterval int @@ -40,11 +40,11 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface inLock: &sync.RWMutex{}, out: make(map[iputil.VpnIp]struct{}), outLock: &sync.RWMutex{}, - TrafficTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60), + TrafficTimer: NewLockingTimerWheel[iputil.VpnIp](time.Millisecond*500, time.Second*60), intf: intf, pendingDeletion: make(map[iputil.VpnIp]int), pendingDeletionLock: &sync.RWMutex{}, - pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60), + pendingDeletionTimer: NewLockingTimerWheel[iputil.VpnIp](time.Millisecond*500, time.Second*60), checkInterval: checkInterval, pendingDeletionInterval: pendingDeletionInterval, l: l, @@ -160,15 +160,13 @@ func (n *connectionManager) Run(ctx context.Context) { } func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) { - n.TrafficTimer.advance(now) + n.TrafficTimer.Advance(now) for { - ep := n.TrafficTimer.Purge() - if ep == nil { + vpnIp, has := n.TrafficTimer.Purge() + if !has { break } - vpnIp := ep.(iputil.VpnIp) - // Check for traffic coming back in from this host. traf := n.CheckIn(vpnIp) @@ -214,15 +212,13 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) } func (n *connectionManager) HandleDeletionTick(now time.Time) { - n.pendingDeletionTimer.advance(now) + n.pendingDeletionTimer.Advance(now) for { - ep := n.pendingDeletionTimer.Purge() - if ep == nil { + vpnIp, has := n.pendingDeletionTimer.Purge() + if !has { break } - vpnIp := ep.(iputil.VpnIp) - hostinfo, err := n.hostMap.QueryVpnIp(vpnIp) if err != nil { n.l.Debugf("Not found in hostmap: %s", vpnIp) diff --git a/firewall.go b/firewall.go index 99b18f8..9fd75fc 100644 --- a/firewall.go +++ b/firewall.go @@ -77,7 +77,7 @@ type FirewallConntrack struct { sync.Mutex Conns map[firewall.Packet]*conn - TimerWheel *TimerWheel + TimerWheel *TimerWheel[firewall.Packet] } type FirewallTable struct { @@ -145,7 +145,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D return &Firewall{ Conntrack: &FirewallConntrack{ Conns: make(map[firewall.Packet]*conn), - TimerWheel: NewTimerWheel(min, max), + TimerWheel: NewTimerWheel[firewall.Packet](min, max), }, InRules: newFirewallTable(), OutRules: newFirewallTable(), @@ -510,6 +510,7 @@ func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) { conntrack := f.Conntrack conntrack.Lock() if _, ok := conntrack.Conns[fp]; !ok { + conntrack.TimerWheel.Advance(time.Now()) conntrack.TimerWheel.Add(fp, timeout) } @@ -537,6 +538,7 @@ func (f *Firewall) evict(p firewall.Packet) { // Timeout is in the future, re-add the timer if newT > 0 { + conntrack.TimerWheel.Advance(time.Now()) conntrack.TimerWheel.Add(p, newT) return } diff --git a/handshake_manager.go b/handshake_manager.go index 4cb9c39..4325841 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -47,7 +47,7 @@ type HandshakeManager struct { lightHouse *LightHouse outside *udp.Conn config HandshakeConfig - OutboundHandshakeTimer *SystemTimerWheel + OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp] messageMetrics *MessageMetrics metricInitiated metrics.Counter metricTimedOut metrics.Counter @@ -65,7 +65,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [ outside: outside, config: config, trigger: make(chan iputil.VpnIp, config.triggerBuffer), - OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)), + OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), @@ -90,13 +90,12 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) { } func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) { - c.OutboundHandshakeTimer.advance(now) + c.OutboundHandshakeTimer.Advance(now) for { - ep := c.OutboundHandshakeTimer.Purge() - if ep == nil { + vpnIp, has := c.OutboundHandshakeTimer.Purge() + if !has { break } - vpnIp := ep.(iputil.VpnIp) c.handleOutbound(vpnIp, f, false) } } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 5381b23..413a50a 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -106,8 +106,8 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) { assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer)) } -func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) { - for _, i := range tw.wheel { +func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { + for _, i := range tw.t.wheel { n := i.Head for n != nil { c++ diff --git a/timeout.go b/timeout.go index 6d8f68b..c1b4c39 100644 --- a/timeout.go +++ b/timeout.go @@ -1,17 +1,14 @@ package nebula import ( + "sync" "time" - - "github.com/slackhq/nebula/firewall" ) // How many timer objects should be cached const timerCacheMax = 50000 -var emptyFWPacket = firewall.Packet{} - -type TimerWheel struct { +type TimerWheel[T any] struct { // Current tick current int @@ -26,31 +23,38 @@ type TimerWheel struct { wheelDuration time.Duration // The actual wheel which is just a set of singly linked lists, head/tail pointers - wheel []*TimeoutList + wheel []*TimeoutList[T] // Singly linked list of items that have timed out of the wheel - expired *TimeoutList + expired *TimeoutList[T] // Item cache to avoid garbage collect - itemCache *TimeoutItem + itemCache *TimeoutItem[T] itemsCached int } +type LockingTimerWheel[T any] struct { + m sync.Mutex + t *TimerWheel[T] +} + // TimeoutList Represents a tick in the wheel -type TimeoutList struct { - Head *TimeoutItem - Tail *TimeoutItem +type TimeoutList[T any] struct { + Head *TimeoutItem[T] + Tail *TimeoutItem[T] } // TimeoutItem Represents an item within a tick -type TimeoutItem struct { - Packet firewall.Packet - Next *TimeoutItem +type TimeoutItem[T any] struct { + Item T + Next *TimeoutItem[T] } // NewTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values // Purge must be called once per entry to actually remove anything -func NewTimerWheel(min, max time.Duration) *TimerWheel { +// The TimerWheel does not handle concurrency on its own. +// Locks around access to it must be used if multiple routines are manipulating it. +func NewTimerWheel[T any](min, max time.Duration) *TimerWheel[T] { //TODO provide an error //if min >= max { // return nil @@ -61,26 +65,31 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel { // timeout wLen := int((max / min) + 2) - tw := TimerWheel{ + tw := TimerWheel[T]{ wheelLen: wLen, - wheel: make([]*TimeoutList, wLen), + wheel: make([]*TimeoutList[T], wLen), tickDuration: min, wheelDuration: max, - expired: &TimeoutList{}, + expired: &TimeoutList[T]{}, } for i := range tw.wheel { - tw.wheel[i] = &TimeoutList{} + tw.wheel[i] = &TimeoutList[T]{} } return &tw } -// Add will add a firewall.Packet to the wheel in it's proper timeout -func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem { - // Check and see if we should progress the tick - tw.advance(time.Now()) +// NewLockingTimerWheel is version of TimerWheel that is safe for concurrent use with a small performance penalty +func NewLockingTimerWheel[T any](min, max time.Duration) *LockingTimerWheel[T] { + return &LockingTimerWheel[T]{ + t: NewTimerWheel[T](min, max), + } +} +// Add will add an item to the wheel in its proper timeout. +// Caller should Advance the wheel prior to ensure the proper slot is used. +func (tw *TimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] { i := tw.findWheel(timeout) // Try to fetch off the cache @@ -90,11 +99,11 @@ func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem tw.itemsCached-- ti.Next = nil } else { - ti = &TimeoutItem{} + ti = &TimeoutItem[T]{} } // Relink and return - ti.Packet = v + ti.Item = v if tw.wheel[i].Tail == nil { tw.wheel[i].Head = ti tw.wheel[i].Tail = ti @@ -106,9 +115,12 @@ func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem return ti } -func (tw *TimerWheel) Purge() (firewall.Packet, bool) { +// Purge removes and returns the first available expired item from the wheel and the 2nd argument is true. +// If no item is available then an empty T is returned and the 2nd argument is false. +func (tw *TimerWheel[T]) Purge() (T, bool) { if tw.expired.Head == nil { - return emptyFWPacket, false + var na T + return na, false } ti := tw.expired.Head @@ -128,11 +140,11 @@ func (tw *TimerWheel) Purge() (firewall.Packet, bool) { tw.itemsCached++ } - return ti.Packet, true + return ti.Item, true } -// advance will move the wheel forward by proper number of ticks. The caller _should_ lock the wheel before calling this -func (tw *TimerWheel) findWheel(timeout time.Duration) (i int) { +// findWheel find the next position in the wheel for the provided timeout given the current tick +func (tw *TimerWheel[T]) findWheel(timeout time.Duration) (i int) { if timeout < tw.tickDuration { // Can't track anything below the set resolution timeout = tw.tickDuration @@ -154,8 +166,9 @@ func (tw *TimerWheel) findWheel(timeout time.Duration) (i int) { return tick } -// advance will lock and move the wheel forward by proper number of ticks. -func (tw *TimerWheel) advance(now time.Time) { +// Advance will move the wheel forward by the appropriate number of ticks for the provided time and all items +// passed over will be moved to the expired list. Calling Purge is necessary to remove them entirely. +func (tw *TimerWheel[T]) Advance(now time.Time) { if tw.lastTick == nil { tw.lastTick = &now } @@ -192,3 +205,21 @@ func (tw *TimerWheel) advance(now time.Time) { newTick := tw.lastTick.Add(tw.tickDuration * time.Duration(adv)) tw.lastTick = &newTick } + +func (lw *LockingTimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] { + lw.m.Lock() + defer lw.m.Unlock() + return lw.t.Add(v, timeout) +} + +func (lw *LockingTimerWheel[T]) Purge() (T, bool) { + lw.m.Lock() + defer lw.m.Unlock() + return lw.t.Purge() +} + +func (lw *LockingTimerWheel[T]) Advance(now time.Time) { + lw.m.Lock() + defer lw.m.Unlock() + lw.t.Advance(now) +} diff --git a/timeout_system.go b/timeout_system.go deleted file mode 100644 index c39d9cd..0000000 --- a/timeout_system.go +++ /dev/null @@ -1,199 +0,0 @@ -package nebula - -import ( - "sync" - "time" - - "github.com/slackhq/nebula/iputil" -) - -// How many timer objects should be cached -const systemTimerCacheMax = 50000 - -type SystemTimerWheel struct { - // Current tick - current int - - // Cheat on finding the length of the wheel - wheelLen int - - // Last time we ticked, since we are lazy ticking - lastTick *time.Time - - // Durations of a tick and the entire wheel - tickDuration time.Duration - wheelDuration time.Duration - - // The actual wheel which is just a set of singly linked lists, head/tail pointers - wheel []*SystemTimeoutList - - // Singly linked list of items that have timed out of the wheel - expired *SystemTimeoutList - - // Item cache to avoid garbage collect - itemCache *SystemTimeoutItem - itemsCached int - - lock sync.Mutex -} - -// SystemTimeoutList Represents a tick in the wheel -type SystemTimeoutList struct { - Head *SystemTimeoutItem - Tail *SystemTimeoutItem -} - -// SystemTimeoutItem Represents an item within a tick -type SystemTimeoutItem struct { - Item iputil.VpnIp - Next *SystemTimeoutItem -} - -// NewSystemTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values -// Purge must be called once per entry to actually remove anything -func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel { - //TODO provide an error - //if min >= max { - // return nil - //} - - // Round down and add 2 so we can have the smallest # of ticks in the wheel and still account for a full - // max duration, even if our current tick is at the maximum position and the next item to be added is at maximum - // timeout - wLen := int((max / min) + 2) - - tw := SystemTimerWheel{ - wheelLen: wLen, - wheel: make([]*SystemTimeoutList, wLen), - tickDuration: min, - wheelDuration: max, - expired: &SystemTimeoutList{}, - } - - for i := range tw.wheel { - tw.wheel[i] = &SystemTimeoutList{} - } - - return &tw -} - -func (tw *SystemTimerWheel) Add(v iputil.VpnIp, timeout time.Duration) *SystemTimeoutItem { - tw.lock.Lock() - defer tw.lock.Unlock() - - // Check and see if we should progress the tick - //tw.advance(time.Now()) - - i := tw.findWheel(timeout) - - // Try to fetch off the cache - ti := tw.itemCache - if ti != nil { - tw.itemCache = ti.Next - ti.Next = nil - tw.itemsCached-- - } else { - ti = &SystemTimeoutItem{} - } - - // Relink and return - ti.Item = v - ti.Next = tw.wheel[i].Head - tw.wheel[i].Head = ti - - if tw.wheel[i].Tail == nil { - tw.wheel[i].Tail = ti - } - - return ti -} - -func (tw *SystemTimerWheel) Purge() interface{} { - tw.lock.Lock() - defer tw.lock.Unlock() - - if tw.expired.Head == nil { - return nil - } - - ti := tw.expired.Head - tw.expired.Head = ti.Next - - if tw.expired.Head == nil { - tw.expired.Tail = nil - } - - p := ti.Item - - // Clear out the items references - ti.Item = 0 - ti.Next = nil - - // Maybe cache it for later - if tw.itemsCached < systemTimerCacheMax { - ti.Next = tw.itemCache - tw.itemCache = ti - tw.itemsCached++ - } - - return p -} - -func (tw *SystemTimerWheel) findWheel(timeout time.Duration) (i int) { - if timeout < tw.tickDuration { - // Can't track anything below the set resolution - timeout = tw.tickDuration - } else if timeout > tw.wheelDuration { - // We aren't handling timeouts greater than the wheels duration - timeout = tw.wheelDuration - } - - // Find the next highest, rounding up - tick := int(((timeout - 1) / tw.tickDuration) + 1) - - // Add another tick since the current tick may almost be over then map it to the wheel from our - // current position - tick += tw.current + 1 - if tick >= tw.wheelLen { - tick -= tw.wheelLen - } - - return tick -} - -func (tw *SystemTimerWheel) advance(now time.Time) { - tw.lock.Lock() - defer tw.lock.Unlock() - - if tw.lastTick == nil { - tw.lastTick = &now - } - - // We want to round down - ticks := int(now.Sub(*tw.lastTick) / tw.tickDuration) - //l.Infoln("Ticks: ", ticks) - for i := 0; i < ticks; i++ { - tw.current++ - //l.Infoln("Tick: ", tw.current) - if tw.current >= tw.wheelLen { - tw.current = 0 - } - - // We need to append the expired items as to not starve evicting the oldest ones - if tw.expired.Tail == nil { - tw.expired.Head = tw.wheel[tw.current].Head - tw.expired.Tail = tw.wheel[tw.current].Tail - } else { - tw.expired.Tail.Next = tw.wheel[tw.current].Head - if tw.wheel[tw.current].Tail != nil { - tw.expired.Tail = tw.wheel[tw.current].Tail - } - } - - //l.Infoln("Head: ", tw.expired.Head, "Tail: ", tw.expired.Tail) - tw.wheel[tw.current].Head = nil - tw.wheel[tw.current].Tail = nil - - tw.lastTick = &now - } -} diff --git a/timeout_system_test.go b/timeout_system_test.go deleted file mode 100644 index ba3c22b..0000000 --- a/timeout_system_test.go +++ /dev/null @@ -1,156 +0,0 @@ -package nebula - -import ( - "net" - "testing" - "time" - - "github.com/slackhq/nebula/iputil" - "github.com/stretchr/testify/assert" -) - -func TestNewSystemTimerWheel(t *testing.T) { - // Make sure we get an object we expect - tw := NewSystemTimerWheel(time.Second, time.Second*10) - assert.Equal(t, 12, tw.wheelLen) - assert.Equal(t, 0, tw.current) - assert.Nil(t, tw.lastTick) - assert.Equal(t, time.Second*1, tw.tickDuration) - assert.Equal(t, time.Second*10, tw.wheelDuration) - assert.Len(t, tw.wheel, 12) - - // Assert the math is correct - tw = NewSystemTimerWheel(time.Second*3, time.Second*10) - assert.Equal(t, 5, tw.wheelLen) - - tw = NewSystemTimerWheel(time.Second*120, time.Minute*10) - assert.Equal(t, 7, tw.wheelLen) -} - -func TestSystemTimerWheel_findWheel(t *testing.T) { - tw := NewSystemTimerWheel(time.Second, time.Second*10) - assert.Len(t, tw.wheel, 12) - - // Current + tick + 1 since we don't know how far into current we are - assert.Equal(t, 2, tw.findWheel(time.Second*1)) - - // Scale up to min duration - assert.Equal(t, 2, tw.findWheel(time.Millisecond*1)) - - // Make sure we hit that last index - assert.Equal(t, 11, tw.findWheel(time.Second*10)) - - // Scale down to max duration - assert.Equal(t, 11, tw.findWheel(time.Second*11)) - - tw.current = 1 - // Make sure we account for the current position properly - assert.Equal(t, 3, tw.findWheel(time.Second*1)) - assert.Equal(t, 0, tw.findWheel(time.Second*10)) - - // Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel - for min := time.Duration(1); min < 100; min++ { - for max := min; max < 100; max++ { - tw = NewSystemTimerWheel(min, max) - - for current := 0; current < tw.wheelLen; current++ { - tw.current = current - for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ { - tick := tw.findWheel(timeout) - if tick >= tw.wheelLen { - t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick) - } - } - } - } - } -} - -func TestSystemTimerWheel_Add(t *testing.T) { - tw := NewSystemTimerWheel(time.Second, time.Second*10) - - fp1 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4")) - tw.Add(fp1, time.Second*1) - - // Make sure we set head and tail properly - assert.NotNil(t, tw.wheel[2]) - assert.Equal(t, fp1, tw.wheel[2].Head.Item) - assert.Nil(t, tw.wheel[2].Head.Next) - assert.Equal(t, fp1, tw.wheel[2].Tail.Item) - assert.Nil(t, tw.wheel[2].Tail.Next) - - // Make sure we only modify head - fp2 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4")) - tw.Add(fp2, time.Second*1) - assert.Equal(t, fp2, tw.wheel[2].Head.Item) - assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item) - assert.Equal(t, fp1, tw.wheel[2].Tail.Item) - assert.Nil(t, tw.wheel[2].Tail.Next) - - // Make sure we use free'd items first - tw.itemCache = &SystemTimeoutItem{} - tw.itemsCached = 1 - tw.Add(fp2, time.Second*1) - assert.Nil(t, tw.itemCache) - assert.Equal(t, 0, tw.itemsCached) -} - -func TestSystemTimerWheel_Purge(t *testing.T) { - // First advance should set the lastTick and do nothing else - tw := NewSystemTimerWheel(time.Second, time.Second*10) - assert.Nil(t, tw.lastTick) - tw.advance(time.Now()) - assert.NotNil(t, tw.lastTick) - assert.Equal(t, 0, tw.current) - - fps := []iputil.VpnIp{9, 10, 11, 12} - - //fp1 := ip2int(net.ParseIP("1.2.3.4")) - - tw.Add(fps[0], time.Second*1) - tw.Add(fps[1], time.Second*1) - tw.Add(fps[2], time.Second*2) - tw.Add(fps[3], time.Second*2) - - ta := time.Now().Add(time.Second * 3) - lastTick := *tw.lastTick - tw.advance(ta) - assert.Equal(t, 3, tw.current) - assert.True(t, tw.lastTick.After(lastTick)) - - // Make sure we get all 4 packets back - for i := 0; i < 4; i++ { - assert.Contains(t, fps, tw.Purge()) - } - - // Make sure there aren't any leftover - assert.Nil(t, tw.Purge()) - assert.Nil(t, tw.expired.Head) - assert.Nil(t, tw.expired.Tail) - - // Make sure we cached the free'd items - assert.Equal(t, 4, tw.itemsCached) - ci := tw.itemCache - for i := 0; i < 4; i++ { - assert.NotNil(t, ci) - ci = ci.Next - } - assert.Nil(t, ci) - - // Lets make sure we roll over properly - ta = ta.Add(time.Second * 5) - tw.advance(ta) - assert.Equal(t, 8, tw.current) - - ta = ta.Add(time.Second * 2) - tw.advance(ta) - assert.Equal(t, 10, tw.current) - - ta = ta.Add(time.Second * 1) - tw.advance(ta) - assert.Equal(t, 11, tw.current) - - ta = ta.Add(time.Second * 1) - tw.advance(ta) - assert.Equal(t, 0, tw.current) -} diff --git a/timeout_test.go b/timeout_test.go index 70b107c..3f81ff4 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -10,7 +10,7 @@ import ( func TestNewTimerWheel(t *testing.T) { // Make sure we get an object we expect - tw := NewTimerWheel(time.Second, time.Second*10) + tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) assert.Equal(t, 12, tw.wheelLen) assert.Equal(t, 0, tw.current) assert.Nil(t, tw.lastTick) @@ -19,15 +19,27 @@ func TestNewTimerWheel(t *testing.T) { assert.Len(t, tw.wheel, 12) // Assert the math is correct - tw = NewTimerWheel(time.Second*3, time.Second*10) + tw = NewTimerWheel[firewall.Packet](time.Second*3, time.Second*10) assert.Equal(t, 5, tw.wheelLen) - tw = NewTimerWheel(time.Second*120, time.Minute*10) + tw = NewTimerWheel[firewall.Packet](time.Second*120, time.Minute*10) assert.Equal(t, 7, tw.wheelLen) + + // Test empty purge of non nil items + i, ok := tw.Purge() + assert.Equal(t, firewall.Packet{}, i) + assert.False(t, ok) + + // Test empty purges of nil items + tw2 := NewTimerWheel[*int](time.Second, time.Second*10) + i2, ok := tw2.Purge() + assert.Nil(t, i2) + assert.False(t, ok) + } func TestTimerWheel_findWheel(t *testing.T) { - tw := NewTimerWheel(time.Second, time.Second*10) + tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) assert.Len(t, tw.wheel, 12) // Current + tick + 1 since we don't know how far into current we are @@ -49,28 +61,28 @@ func TestTimerWheel_findWheel(t *testing.T) { } func TestTimerWheel_Add(t *testing.T) { - tw := NewTimerWheel(time.Second, time.Second*10) + tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) fp1 := firewall.Packet{} tw.Add(fp1, time.Second*1) // Make sure we set head and tail properly assert.NotNil(t, tw.wheel[2]) - assert.Equal(t, fp1, tw.wheel[2].Head.Packet) + assert.Equal(t, fp1, tw.wheel[2].Head.Item) assert.Nil(t, tw.wheel[2].Head.Next) - assert.Equal(t, fp1, tw.wheel[2].Tail.Packet) + assert.Equal(t, fp1, tw.wheel[2].Tail.Item) assert.Nil(t, tw.wheel[2].Tail.Next) // Make sure we only modify head fp2 := firewall.Packet{} tw.Add(fp2, time.Second*1) - assert.Equal(t, fp2, tw.wheel[2].Head.Packet) - assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet) - assert.Equal(t, fp1, tw.wheel[2].Tail.Packet) + assert.Equal(t, fp2, tw.wheel[2].Head.Item) + assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item) + assert.Equal(t, fp1, tw.wheel[2].Tail.Item) assert.Nil(t, tw.wheel[2].Tail.Next) // Make sure we use free'd items first - tw.itemCache = &TimeoutItem{} + tw.itemCache = &TimeoutItem[firewall.Packet]{} tw.itemsCached = 1 tw.Add(fp2, time.Second*1) assert.Nil(t, tw.itemCache) @@ -79,7 +91,7 @@ func TestTimerWheel_Add(t *testing.T) { // Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel for min := time.Duration(1); min < 100; min++ { for max := min; max < 100; max++ { - tw = NewTimerWheel(min, max) + tw = NewTimerWheel[firewall.Packet](min, max) for current := 0; current < tw.wheelLen; current++ { tw.current = current @@ -96,9 +108,9 @@ func TestTimerWheel_Add(t *testing.T) { func TestTimerWheel_Purge(t *testing.T) { // First advance should set the lastTick and do nothing else - tw := NewTimerWheel(time.Second, time.Second*10) + tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) assert.Nil(t, tw.lastTick) - tw.advance(time.Now()) + tw.Advance(time.Now()) assert.NotNil(t, tw.lastTick) assert.Equal(t, 0, tw.current) @@ -116,7 +128,7 @@ func TestTimerWheel_Purge(t *testing.T) { ta := time.Now().Add(time.Second * 3) lastTick := *tw.lastTick - tw.advance(ta) + tw.Advance(ta) assert.Equal(t, 3, tw.current) assert.True(t, tw.lastTick.After(lastTick)) @@ -142,20 +154,20 @@ func TestTimerWheel_Purge(t *testing.T) { } assert.Nil(t, ci) - // Lets make sure we roll over properly + // Let's make sure we roll over properly ta = ta.Add(time.Second * 5) - tw.advance(ta) + tw.Advance(ta) assert.Equal(t, 8, tw.current) ta = ta.Add(time.Second * 2) - tw.advance(ta) + tw.Advance(ta) assert.Equal(t, 10, tw.current) ta = ta.Add(time.Second * 1) - tw.advance(ta) + tw.Advance(ta) assert.Equal(t, 11, tw.current) ta = ta.Add(time.Second * 1) - tw.advance(ta) + tw.Advance(ta) assert.Equal(t, 0, tw.current) }