From 9af242dc47706c3cdd01f829e98e7c859a3704ab Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Mon, 31 Oct 2022 13:37:41 -0400 Subject: [PATCH] switch to new sync/atomic helpers in go1.19 (#728) These new helpers make the code a lot cleaner. I confirmed that the simple helpers like `atomic.Int64` don't add any extra overhead as they get inlined by the compiler. `atomic.Pointer` adds an extra method call as it no longer gets inlined, but we aren't using these on the hot path so it is probably okay. --- .github/workflows/gofmt.yml | 8 ++-- .github/workflows/release.yml | 12 +++--- .github/workflows/smoke.yml | 8 ++-- .github/workflows/test.yml | 16 ++++---- Makefile | 2 +- cmd/nebula-service/main.go | 2 +- cmd/nebula/main.go | 2 +- connection_manager_test.go | 20 ++++++++-- connection_state.go | 24 +++++------ control.go | 3 +- firewall.go | 2 +- firewall/cache.go | 6 +-- go.mod | 2 +- handshake_ix.go | 3 +- handshake_manager_test.go | 13 +----- hostmap.go | 9 ++--- inside.go | 6 +-- interface.go | 6 +-- lighthouse.go | 75 ++++++++++++++++++----------------- punchy.go | 29 +++++--------- relay_manager.go | 17 +++----- remote_list.go | 2 +- wintun/tun.go | 4 -- 23 files changed, 126 insertions(+), 145 deletions(-) diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml index ddfca5a..a00453b 100644 --- a/.github/workflows/gofmt.yml +++ b/.github/workflows/gofmt.yml @@ -14,10 +14,10 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.18 + - name: Set up Go 1.19 uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.19 id: go - name: Check out code into the Go module directory @@ -26,9 +26,9 @@ jobs: - uses: actions/cache@v2 with: path: ~/go/pkg/mod - key: ${{ runner.os }}-gofmt1.18-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-gofmt1.19-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-gofmt1.18- + ${{ runner.os }}-gofmt1.19- - name: Install goimports run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7485beb..572b0ff 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,10 +10,10 @@ jobs: name: Build Linux All runs-on: ubuntu-latest steps: - - name: Set up Go 1.18 + - name: Set up Go 1.19 uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.19 - name: Checkout code uses: actions/checkout@v2 @@ -34,10 +34,10 @@ jobs: name: Build Windows runs-on: windows-latest steps: - - name: Set up Go 1.18 + - name: Set up Go 1.19 uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.19 - name: Checkout code uses: actions/checkout@v2 @@ -68,10 +68,10 @@ jobs: HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} runs-on: macos-11 steps: - - name: Set up Go 1.18 + - name: Set up Go 1.19 uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.19 - name: Checkout code uses: actions/checkout@v2 diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml index 9920992..162d526 100644 --- a/.github/workflows/smoke.yml +++ b/.github/workflows/smoke.yml @@ -18,10 +18,10 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.18 + - name: Set up Go 1.19 uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.19 id: go - name: Check out code into the Go module directory @@ -30,9 +30,9 @@ jobs: - uses: actions/cache@v2 with: path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.18-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go1.18- + ${{ runner.os }}-go1.19- - name: build run: make bin-docker diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fb69112..69ed606 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,10 +18,10 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.18 + - name: Set up Go 1.19 uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.19 id: go - name: Check out code into the Go module directory @@ -30,9 +30,9 @@ jobs: - uses: actions/cache@v2 with: path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.18-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go1.18- + ${{ runner.os }}-go1.19- - name: Build run: make all @@ -57,10 +57,10 @@ jobs: os: [windows-latest, macos-11] steps: - - name: Set up Go 1.18 + - name: Set up Go 1.19 uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.19 id: go - name: Check out code into the Go module directory @@ -69,9 +69,9 @@ jobs: - uses: actions/cache@v2 with: path: ~/go/pkg/mod - key: ${{ runner.os }}-go1.18-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go1.18- + ${{ runner.os }}-go1.19- - name: Build nebula run: go build ./cmd/nebula diff --git a/Makefile b/Makefile index 188ffea..b31c0fc 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -GOMINVERSION = 1.18 +GOMINVERSION = 1.19 NEBULA_CMD_PATH = "./cmd/nebula" GO111MODULE = on export GO111MODULE diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index f211c97..c1de267 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -13,7 +13,7 @@ import ( // A version string that can be set with // -// -ldflags "-X main.Build=SOMEVERSION" +// -ldflags "-X main.Build=SOMEVERSION" // // at compile-time. var Build string diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index efe406b..e9b285e 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -13,7 +13,7 @@ import ( // A version string that can be set with // -// -ldflags "-X main.Build=SOMEVERSION" +// -ldflags "-X main.Build=SOMEVERSION" // // at compile-time. var Build string diff --git a/connection_manager_test.go b/connection_manager_test.go index bae48e5..df42800 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -18,6 +18,20 @@ import ( var vpnIp iputil.VpnIp +func newTestLighthouse() *LightHouse { + lh := &LightHouse{ + l: test.NewLogger(), + addrMap: map[iputil.VpnIp]*RemoteList{}, + } + lighthouses := map[iputil.VpnIp]struct{}{} + staticList := map[iputil.VpnIp]struct{}{} + + lh.lighthouses.Store(&lighthouses) + lh.staticList.Store(&staticList) + + return lh +} + func Test_NewConnectionManagerTest(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") @@ -35,7 +49,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})} + lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, @@ -104,7 +118,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})} + lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, @@ -213,7 +227,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})} + lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, diff --git a/connection_state.go b/connection_state.go index c28cc42..6bbb02f 100644 --- a/connection_state.go +++ b/connection_state.go @@ -14,17 +14,17 @@ import ( const ReplayWindow = 1024 type ConnectionState struct { - eKey *NebulaCipherState - dKey *NebulaCipherState - H *noise.HandshakeState - certState *CertState - peerCert *cert.NebulaCertificate - initiator bool - atomicMessageCounter uint64 - window *Bits - queueLock sync.Mutex - writeLock sync.Mutex - ready bool + eKey *NebulaCipherState + dKey *NebulaCipherState + H *noise.HandshakeState + certState *CertState + peerCert *cert.NebulaCertificate + initiator bool + messageCounter atomic.Uint64 + window *Bits + queueLock sync.Mutex + writeLock sync.Mutex + ready bool } func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { @@ -70,7 +70,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) { return json.Marshal(m{ "certificate": cs.peerCert, "initiator": cs.initiator, - "message_counter": atomic.LoadUint64(&cs.atomicMessageCounter), + "message_counter": cs.messageCounter.Load(), "ready": cs.ready, }) } diff --git a/control.go b/control.go index 6e7bda1..2e7ffee 100644 --- a/control.go +++ b/control.go @@ -5,7 +5,6 @@ import ( "net" "os" "os/signal" - "sync/atomic" "syscall" "github.com/sirupsen/logrus" @@ -219,7 +218,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { } if h.ConnectionState != nil { - chi.MessageCounter = atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter) + chi.MessageCounter = h.ConnectionState.messageCounter.Load() } if c := h.GetCert(); c != nil { diff --git a/firewall.go b/firewall.go index dfc7fd1..99b18f8 100644 --- a/firewall.go +++ b/firewall.go @@ -879,7 +879,7 @@ func parsePort(s string) (startPort, endPort int32, err error) { return } -//TODO: write tests for these +// TODO: write tests for these func setTCPRTTTracking(c *conn, p []byte) { if c.Seq != 0 { return diff --git a/firewall/cache.go b/firewall/cache.go index 5560ab2..71b83f4 100644 --- a/firewall/cache.go +++ b/firewall/cache.go @@ -13,7 +13,7 @@ type ConntrackCache map[Packet]struct{} type ConntrackCacheTicker struct { cacheV uint64 - cacheTick uint64 + cacheTick atomic.Uint64 cache ConntrackCache } @@ -35,7 +35,7 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { func (c *ConntrackCacheTicker) tick(d time.Duration) { for { time.Sleep(d) - atomic.AddUint64(&c.cacheTick, 1) + c.cacheTick.Add(1) } } @@ -45,7 +45,7 @@ func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { if c == nil { return nil } - if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV { + if tick := c.cacheTick.Load(); tick != c.cacheV { c.cacheV = tick if ll := len(c.cache); ll > 0 { if l.Level == logrus.DebugLevel { diff --git a/go.mod b/go.mod index 69bb424..5e7393e 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/slackhq/nebula -go 1.18 +go 1.19 require ( github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be diff --git a/handshake_ix.go b/handshake_ix.go index fd1a908..11a16a6 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -1,7 +1,6 @@ package nebula import ( - "sync/atomic" "time" "github.com/flynn/noise" @@ -51,7 +50,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { } h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) - atomic.AddUint64(&ci.atomicMessageCounter, 1) + ci.messageCounter.Add(1) msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { diff --git a/handshake_manager_test.go b/handshake_manager_test.go index ae8b267..5381b23 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -21,11 +21,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - lh := &LightHouse{ - atomicStaticList: make(map[iputil.VpnIp]struct{}), - atomicLighthouses: make(map[iputil.VpnIp]struct{}), - addrMap: make(map[iputil.VpnIp]*RemoteList), - } + lh := newTestLighthouse() blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig) @@ -79,12 +75,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) { preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) - lh := &LightHouse{ - addrMap: make(map[iputil.VpnIp]*RemoteList), - l: l, - atomicStaticList: make(map[iputil.VpnIp]struct{}), - atomicLighthouses: make(map[iputil.VpnIp]struct{}), - } + lh := newTestLighthouse() blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig) diff --git a/hostmap.go b/hostmap.go index 402c1a8..84b2041 100644 --- a/hostmap.go +++ b/hostmap.go @@ -18,7 +18,7 @@ import ( "github.com/slackhq/nebula/udp" ) -//const ProbeLen = 100 +// const ProbeLen = 100 const PromoteEvery = 1000 const ReQueryEvery = 5000 const MaxRemotes = 10 @@ -153,7 +153,7 @@ type HostInfo struct { remote *udp.Addr remotes *RemoteList - promoteCounter uint32 + promoteCounter atomic.Uint32 ConnectionState *ConnectionState handshakeStart time.Time //todo: this an entry in the handshake manager HandshakeReady bool //todo: being in the manager means you are ready @@ -284,7 +284,6 @@ func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) ( if h, ok := hm.Hosts[vpnIp]; !ok { hm.RUnlock() h = &HostInfo{ - promoteCounter: 0, vpnIp: vpnIp, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ @@ -591,7 +590,7 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) { // TryPromoteBest handles re-querying lighthouses and probing for better paths // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { - c := atomic.AddUint32(&i.promoteCounter, 1) + c := i.promoteCounter.Add(1) if c%PromoteEvery == 0 { // The lock here is currently protecting i.remote access i.RLock() @@ -658,7 +657,7 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) { i.HandshakeComplete = true //TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen. // Clamping it to 2 gets us out of the woods for now - atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2) + i.ConnectionState.messageCounter.Store(2) if l.Level >= logrus.DebugLevel { i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore)) diff --git a/inside.go b/inside.go index 39c5a43..177bcd3 100644 --- a/inside.go +++ b/inside.go @@ -1,8 +1,6 @@ package nebula import ( - "sync/atomic" - "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" @@ -222,7 +220,7 @@ func (f *Interface) SendVia(viaIfc interface{}, ) { via := viaIfc.(*HostInfo) relay := relayIfc.(*Relay) - c := atomic.AddUint64(&via.ConnectionState.atomicMessageCounter, 1) + c := via.ConnectionState.messageCounter.Add(1) out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c) f.connectionManager.Out(via.vpnIp) @@ -281,7 +279,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType //TODO: enable if we do more than 1 tun queue //ci.writeLock.Lock() - c := atomic.AddUint64(&ci.atomicMessageCounter, 1) + c := ci.messageCounter.Add(1) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c) diff --git a/interface.go b/interface.go index a84eb7f..632e823 100644 --- a/interface.go +++ b/interface.go @@ -67,7 +67,7 @@ type Interface struct { routines int caPool *cert.NebulaCAPool disconnectInvalid bool - closed int32 + closed atomic.Bool relayManager *relayManager sendRecvErrorConfig sendRecvErrorConfig @@ -253,7 +253,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { for { n, err := reader.Read(packet) if err != nil { - if errors.Is(err, os.ErrClosed) && atomic.LoadInt32(&f.closed) != 0 { + if errors.Is(err, os.ErrClosed) && f.closed.Load() { return } @@ -391,7 +391,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { } func (f *Interface) Close() error { - atomic.StoreInt32(&f.closed, 1) + f.closed.Store(true) // Release the tun device return f.inside.Close() diff --git a/lighthouse.go b/lighthouse.go index 4987603..60e1f29 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -9,7 +9,6 @@ import ( "sync" "sync/atomic" "time" - "unsafe" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" @@ -49,29 +48,29 @@ type LightHouse struct { // respond with. // - When we are not a lighthouse, this filters which addresses we accept // from lighthouses. - atomicRemoteAllowList *RemoteAllowList + remoteAllowList atomic.Pointer[RemoteAllowList] // filters local addresses that we advertise to lighthouses - atomicLocalAllowList *LocalAllowList + localAllowList atomic.Pointer[LocalAllowList] // used to trigger the HandshakeManager when we receive HostQueryReply handshakeTrigger chan<- iputil.VpnIp - // atomicStaticList exists to avoid having a bool in each addrMap entry + // staticList exists to avoid having a bool in each addrMap entry // since static should be rare - atomicStaticList map[iputil.VpnIp]struct{} - atomicLighthouses map[iputil.VpnIp]struct{} + staticList atomic.Pointer[map[iputil.VpnIp]struct{}] + lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}] - atomicInterval int64 + interval atomic.Int64 updateCancel context.CancelFunc updateParentCtx context.Context updateUdp udp.EncWriter nebulaPort uint32 // 32 bits because protobuf does not have a uint16 - atomicAdvertiseAddrs []netIpAndPort + advertiseAddrs atomic.Pointer[[]netIpAndPort] // IP's of relays that can be used by peers to access me - atomicRelaysForMe []iputil.VpnIp + relaysForMe atomic.Pointer[[]iputil.VpnIp] metrics *MessageMetrics metricHolepunchTx metrics.Counter @@ -98,18 +97,20 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, ones, _ := myVpnNet.Mask.Size() h := LightHouse{ - amLighthouse: amLighthouse, - myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), - myVpnZeros: iputil.VpnIp(32 - ones), - myVpnNet: myVpnNet, - addrMap: make(map[iputil.VpnIp]*RemoteList), - nebulaPort: nebulaPort, - atomicLighthouses: make(map[iputil.VpnIp]struct{}), - atomicStaticList: make(map[iputil.VpnIp]struct{}), - punchConn: pc, - punchy: p, - l: l, + amLighthouse: amLighthouse, + myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), + myVpnZeros: iputil.VpnIp(32 - ones), + myVpnNet: myVpnNet, + addrMap: make(map[iputil.VpnIp]*RemoteList), + nebulaPort: nebulaPort, + punchConn: pc, + punchy: p, + l: l, } + lighthouses := make(map[iputil.VpnIp]struct{}) + h.lighthouses.Store(&lighthouses) + staticList := make(map[iputil.VpnIp]struct{}) + h.staticList.Store(&staticList) if c.GetBool("stats.lighthouse_metrics", false) { h.metrics = newLighthouseMetrics() @@ -137,31 +138,31 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, } func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} { - return *(*map[iputil.VpnIp]struct{})(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicStaticList)))) + return *lh.staticList.Load() } func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} { - return *(*map[iputil.VpnIp]struct{})(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLighthouses)))) + return *lh.lighthouses.Load() } func (lh *LightHouse) GetRemoteAllowList() *RemoteAllowList { - return (*RemoteAllowList)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRemoteAllowList)))) + return lh.remoteAllowList.Load() } func (lh *LightHouse) GetLocalAllowList() *LocalAllowList { - return (*LocalAllowList)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLocalAllowList)))) + return lh.localAllowList.Load() } func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort { - return *(*[]netIpAndPort)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicAdvertiseAddrs)))) + return *lh.advertiseAddrs.Load() } func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { - return *(*[]iputil.VpnIp)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRelaysForMe)))) + return *lh.relaysForMe.Load() } func (lh *LightHouse) GetUpdateInterval() int64 { - return atomic.LoadInt64(&lh.atomicInterval) + return lh.interval.Load() } func (lh *LightHouse) reload(c *config.C, initial bool) error { @@ -188,7 +189,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort}) } - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicAdvertiseAddrs)), unsafe.Pointer(&advAddrs)) + lh.advertiseAddrs.Store(&advAddrs) if !initial { lh.l.Info("lighthouse.advertise_addrs has changed") @@ -196,10 +197,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } if initial || c.HasChanged("lighthouse.interval") { - atomic.StoreInt64(&lh.atomicInterval, int64(c.GetInt("lighthouse.interval", 10))) + lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10))) if !initial { - lh.l.Infof("lighthouse.interval changed to %v", lh.atomicInterval) + lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load()) if lh.updateCancel != nil { // May not always have a running routine @@ -216,7 +217,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) } - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRemoteAllowList)), unsafe.Pointer(ral)) + lh.remoteAllowList.Store(ral) if !initial { //TODO: a diff will be annoyingly difficult lh.l.Info("lighthouse.remote_allow_list and/or lighthouse.remote_allow_ranges has changed") @@ -229,7 +230,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err) } - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLocalAllowList)), unsafe.Pointer(lal)) + lh.localAllowList.Store(lal) if !initial { //TODO: a diff will be annoyingly difficult lh.l.Info("lighthouse.local_allow_list has changed") @@ -244,7 +245,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return err } - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicStaticList)), unsafe.Pointer(&staticList)) + lh.staticList.Store(&staticList) if !initial { //TODO: we should remove any remote list entries for static hosts that were removed/modified? lh.l.Info("static_host_map has changed") @@ -259,7 +260,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return err } - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLighthouses)), unsafe.Pointer(&lhMap)) + lh.lighthouses.Store(&lhMap) if !initial { //NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic lh.l.Info("lighthouse.hosts has changed") @@ -274,7 +275,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.l.Info("Ignoring relays from config because am_relay is true") } relaysForMe := []iputil.VpnIp{} - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRelaysForMe)), unsafe.Pointer(&relaysForMe)) + lh.relaysForMe.Store(&relaysForMe) case false: relaysForMe := []iputil.VpnIp{} for _, v := range c.GetStringSlice("relay.relays", nil) { @@ -285,7 +286,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP)) } } - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRelaysForMe)), unsafe.Pointer(&relaysForMe)) + lh.relaysForMe.Store(&relaysForMe) } } @@ -460,7 +461,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { // AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client -//NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it +// NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) { lh.Lock() am := lh.unlockedGetRemoteList(vpnIp) diff --git a/punchy.go b/punchy.go index d81ed83..1ecf7c5 100644 --- a/punchy.go +++ b/punchy.go @@ -9,10 +9,10 @@ import ( ) type Punchy struct { - atomicPunch int32 - atomicRespond int32 - atomicDelay time.Duration - l *logrus.Logger + punch atomic.Bool + respond atomic.Bool + delay atomic.Int64 + l *logrus.Logger } func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy { @@ -36,12 +36,7 @@ func (p *Punchy) reload(c *config.C, initial bool) { yes = c.GetBool("punchy", false) } - if yes { - atomic.StoreInt32(&p.atomicPunch, 1) - } else { - atomic.StoreInt32(&p.atomicPunch, 0) - } - + p.punch.Store(yes) } else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") { //TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.") @@ -56,11 +51,7 @@ func (p *Punchy) reload(c *config.C, initial bool) { yes = c.GetBool("punch_back", false) } - if yes { - atomic.StoreInt32(&p.atomicRespond, 1) - } else { - atomic.StoreInt32(&p.atomicRespond, 0) - } + p.respond.Store(yes) if !initial { p.l.Infof("punchy.respond changed to %v", p.GetRespond()) @@ -69,7 +60,7 @@ func (p *Punchy) reload(c *config.C, initial bool) { //NOTE: this will not apply to any in progress operations, only the next one if initial || c.HasChanged("punchy.delay") { - atomic.StoreInt64((*int64)(&p.atomicDelay), (int64)(c.GetDuration("punchy.delay", time.Second))) + p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second))) if !initial { p.l.Infof("punchy.delay changed to %s", p.GetDelay()) } @@ -77,13 +68,13 @@ func (p *Punchy) reload(c *config.C, initial bool) { } func (p *Punchy) GetPunch() bool { - return atomic.LoadInt32(&p.atomicPunch) == 1 + return p.punch.Load() } func (p *Punchy) GetRespond() bool { - return atomic.LoadInt32(&p.atomicRespond) == 1 + return p.respond.Load() } func (p *Punchy) GetDelay() time.Duration { - return (time.Duration)(atomic.LoadInt64((*int64)(&p.atomicDelay))) + return (time.Duration)(p.delay.Load()) } diff --git a/relay_manager.go b/relay_manager.go index 145e319..95807bd 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -13,9 +13,9 @@ import ( ) type relayManager struct { - l *logrus.Logger - hostmap *HostMap - atomicAmRelay int32 + l *logrus.Logger + hostmap *HostMap + amRelay atomic.Bool } func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager { @@ -41,18 +41,11 @@ func (rm *relayManager) reload(c *config.C, initial bool) error { } func (rm *relayManager) GetAmRelay() bool { - return atomic.LoadInt32(&rm.atomicAmRelay) == 1 + return rm.amRelay.Load() } func (rm *relayManager) setAmRelay(v bool) { - var val int32 - switch v { - case true: - val = 1 - case false: - val = 0 - } - atomic.StoreInt32(&rm.atomicAmRelay, val) + rm.amRelay.Store(v) } // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. diff --git a/remote_list.go b/remote_list.go index af66891..4b544f6 100644 --- a/remote_list.go +++ b/remote_list.go @@ -130,7 +130,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { // LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // It will mark the deduplicated address list as dirty, so do not call it unless new information is available -//TODO: this needs to support the allow list list +// TODO: this needs to support the allow list list func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) { r.Lock() defer r.Unlock() diff --git a/wintun/tun.go b/wintun/tun.go index a2dfe5e..c167e70 100644 --- a/wintun/tun.go +++ b/wintun/tun.go @@ -59,18 +59,14 @@ func procyield(cycles uint32) //go:linkname nanotime runtime.nanotime func nanotime() int64 -// // CreateTUN creates a Wintun interface with the given name. Should a Wintun // interface with the same name exist, it is reused. -// func CreateTUN(ifname string, mtu int) (Device, error) { return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) } -// // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // a requested GUID. Should a Wintun interface with the same name exist, it is reused. -// func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) if err != nil {