switch to new sync/atomic helpers in go1.19 (#728)

These new helpers make the code a lot cleaner. I confirmed that the
simple helpers like `atomic.Int64` don't add any extra overhead as they
get inlined by the compiler. `atomic.Pointer` adds an extra method call
as it no longer gets inlined, but we aren't using these on the hot path
so it is probably okay.
This commit is contained in:
Wade Simmons 2022-10-31 13:37:41 -04:00 committed by GitHub
parent a800a48857
commit 9af242dc47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 126 additions and 145 deletions

View File

@ -14,10 +14,10 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.18 - name: Set up Go 1.19
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.18 go-version: 1.19
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
@ -26,9 +26,9 @@ jobs:
- uses: actions/cache@v2 - uses: actions/cache@v2
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-gofmt1.18-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-gofmt1.19-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-gofmt1.18- ${{ runner.os }}-gofmt1.19-
- name: Install goimports - name: Install goimports
run: | run: |

View File

@ -10,10 +10,10 @@ jobs:
name: Build Linux All name: Build Linux All
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.18 - name: Set up Go 1.19
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.18 go-version: 1.19
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
@ -34,10 +34,10 @@ jobs:
name: Build Windows name: Build Windows
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Set up Go 1.18 - name: Set up Go 1.19
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.18 go-version: 1.19
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
@ -68,10 +68,10 @@ jobs:
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
runs-on: macos-11 runs-on: macos-11
steps: steps:
- name: Set up Go 1.18 - name: Set up Go 1.19
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.18 go-version: 1.19
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2

View File

@ -18,10 +18,10 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.18 - name: Set up Go 1.19
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.18 go-version: 1.19
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
@ -30,9 +30,9 @@ jobs:
- uses: actions/cache@v2 - uses: actions/cache@v2
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go1.18-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go1.18- ${{ runner.os }}-go1.19-
- name: build - name: build
run: make bin-docker run: make bin-docker

View File

@ -18,10 +18,10 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.18 - name: Set up Go 1.19
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.18 go-version: 1.19
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
@ -30,9 +30,9 @@ jobs:
- uses: actions/cache@v2 - uses: actions/cache@v2
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go1.18-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go1.18- ${{ runner.os }}-go1.19-
- name: Build - name: Build
run: make all run: make all
@ -57,10 +57,10 @@ jobs:
os: [windows-latest, macos-11] os: [windows-latest, macos-11]
steps: steps:
- name: Set up Go 1.18 - name: Set up Go 1.19
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.18 go-version: 1.19
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
@ -69,9 +69,9 @@ jobs:
- uses: actions/cache@v2 - uses: actions/cache@v2
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go1.18-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go1.18- ${{ runner.os }}-go1.19-
- name: Build nebula - name: Build nebula
run: go build ./cmd/nebula run: go build ./cmd/nebula

View File

@ -1,4 +1,4 @@
GOMINVERSION = 1.18 GOMINVERSION = 1.19
NEBULA_CMD_PATH = "./cmd/nebula" NEBULA_CMD_PATH = "./cmd/nebula"
GO111MODULE = on GO111MODULE = on
export GO111MODULE export GO111MODULE

View File

@ -13,7 +13,7 @@ import (
// A version string that can be set with // A version string that can be set with
// //
// -ldflags "-X main.Build=SOMEVERSION" // -ldflags "-X main.Build=SOMEVERSION"
// //
// at compile-time. // at compile-time.
var Build string var Build string

View File

@ -13,7 +13,7 @@ import (
// A version string that can be set with // A version string that can be set with
// //
// -ldflags "-X main.Build=SOMEVERSION" // -ldflags "-X main.Build=SOMEVERSION"
// //
// at compile-time. // at compile-time.
var Build string var Build string

View File

@ -18,6 +18,20 @@ import (
var vpnIp iputil.VpnIp var vpnIp iputil.VpnIp
func newTestLighthouse() *LightHouse {
lh := &LightHouse{
l: test.NewLogger(),
addrMap: map[iputil.VpnIp]*RemoteList{},
}
lighthouses := map[iputil.VpnIp]struct{}{}
staticList := map[iputil.VpnIp]struct{}{}
lh.lighthouses.Store(&lighthouses)
lh.staticList.Store(&staticList)
return lh
}
func Test_NewConnectionManagerTest(t *testing.T) { func Test_NewConnectionManagerTest(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
@ -35,7 +49,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})} lh := newTestLighthouse()
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &test.NoopTun{}, inside: &test.NoopTun{},
@ -104,7 +118,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})} lh := newTestLighthouse()
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &test.NoopTun{}, inside: &test.NoopTun{},
@ -213,7 +227,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})} lh := newTestLighthouse()
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &test.NoopTun{}, inside: &test.NoopTun{},

View File

@ -14,17 +14,17 @@ import (
const ReplayWindow = 1024 const ReplayWindow = 1024
type ConnectionState struct { type ConnectionState struct {
eKey *NebulaCipherState eKey *NebulaCipherState
dKey *NebulaCipherState dKey *NebulaCipherState
H *noise.HandshakeState H *noise.HandshakeState
certState *CertState certState *CertState
peerCert *cert.NebulaCertificate peerCert *cert.NebulaCertificate
initiator bool initiator bool
atomicMessageCounter uint64 messageCounter atomic.Uint64
window *Bits window *Bits
queueLock sync.Mutex queueLock sync.Mutex
writeLock sync.Mutex writeLock sync.Mutex
ready bool ready bool
} }
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
@ -70,7 +70,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
return json.Marshal(m{ return json.Marshal(m{
"certificate": cs.peerCert, "certificate": cs.peerCert,
"initiator": cs.initiator, "initiator": cs.initiator,
"message_counter": atomic.LoadUint64(&cs.atomicMessageCounter), "message_counter": cs.messageCounter.Load(),
"ready": cs.ready, "ready": cs.ready,
}) })
} }

View File

@ -5,7 +5,6 @@ import (
"net" "net"
"os" "os"
"os/signal" "os/signal"
"sync/atomic"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -219,7 +218,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
} }
if h.ConnectionState != nil { if h.ConnectionState != nil {
chi.MessageCounter = atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter) chi.MessageCounter = h.ConnectionState.messageCounter.Load()
} }
if c := h.GetCert(); c != nil { if c := h.GetCert(); c != nil {

View File

@ -879,7 +879,7 @@ func parsePort(s string) (startPort, endPort int32, err error) {
return return
} }
//TODO: write tests for these // TODO: write tests for these
func setTCPRTTTracking(c *conn, p []byte) { func setTCPRTTTracking(c *conn, p []byte) {
if c.Seq != 0 { if c.Seq != 0 {
return return

View File

@ -13,7 +13,7 @@ type ConntrackCache map[Packet]struct{}
type ConntrackCacheTicker struct { type ConntrackCacheTicker struct {
cacheV uint64 cacheV uint64
cacheTick uint64 cacheTick atomic.Uint64
cache ConntrackCache cache ConntrackCache
} }
@ -35,7 +35,7 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
func (c *ConntrackCacheTicker) tick(d time.Duration) { func (c *ConntrackCacheTicker) tick(d time.Duration) {
for { for {
time.Sleep(d) time.Sleep(d)
atomic.AddUint64(&c.cacheTick, 1) c.cacheTick.Add(1)
} }
} }
@ -45,7 +45,7 @@ func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
if c == nil { if c == nil {
return nil return nil
} }
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV { if tick := c.cacheTick.Load(); tick != c.cacheV {
c.cacheV = tick c.cacheV = tick
if ll := len(c.cache); ll > 0 { if ll := len(c.cache); ll > 0 {
if l.Level == logrus.DebugLevel { if l.Level == logrus.DebugLevel {

2
go.mod
View File

@ -1,6 +1,6 @@
module github.com/slackhq/nebula module github.com/slackhq/nebula
go 1.18 go 1.19
require ( require (
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be

View File

@ -1,7 +1,6 @@
package nebula package nebula
import ( import (
"sync/atomic"
"time" "time"
"github.com/flynn/noise" "github.com/flynn/noise"
@ -51,7 +50,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
} }
h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
atomic.AddUint64(&ci.atomicMessageCounter, 1) ci.messageCounter.Add(1)
msg, _, _, err := ci.H.WriteMessage(h, hsBytes) msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil { if err != nil {

View File

@ -21,11 +21,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{} mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
lh := &LightHouse{ lh := newTestLighthouse()
atomicStaticList: make(map[iputil.VpnIp]struct{}),
atomicLighthouses: make(map[iputil.VpnIp]struct{}),
addrMap: make(map[iputil.VpnIp]*RemoteList),
}
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
@ -79,12 +75,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{} mw := &mockEncWriter{}
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
lh := &LightHouse{ lh := newTestLighthouse()
addrMap: make(map[iputil.VpnIp]*RemoteList),
l: l,
atomicStaticList: make(map[iputil.VpnIp]struct{}),
atomicLighthouses: make(map[iputil.VpnIp]struct{}),
}
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)

View File

@ -18,7 +18,7 @@ import (
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
//const ProbeLen = 100 // const ProbeLen = 100
const PromoteEvery = 1000 const PromoteEvery = 1000
const ReQueryEvery = 5000 const ReQueryEvery = 5000
const MaxRemotes = 10 const MaxRemotes = 10
@ -153,7 +153,7 @@ type HostInfo struct {
remote *udp.Addr remote *udp.Addr
remotes *RemoteList remotes *RemoteList
promoteCounter uint32 promoteCounter atomic.Uint32
ConnectionState *ConnectionState ConnectionState *ConnectionState
handshakeStart time.Time //todo: this an entry in the handshake manager handshakeStart time.Time //todo: this an entry in the handshake manager
HandshakeReady bool //todo: being in the manager means you are ready HandshakeReady bool //todo: being in the manager means you are ready
@ -284,7 +284,6 @@ func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (
if h, ok := hm.Hosts[vpnIp]; !ok { if h, ok := hm.Hosts[vpnIp]; !ok {
hm.RUnlock() hm.RUnlock()
h = &HostInfo{ h = &HostInfo{
promoteCounter: 0,
vpnIp: vpnIp, vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{ relayState: RelayState{
@ -591,7 +590,7 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
// TryPromoteBest handles re-querying lighthouses and probing for better paths // TryPromoteBest handles re-querying lighthouses and probing for better paths
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
c := atomic.AddUint32(&i.promoteCounter, 1) c := i.promoteCounter.Add(1)
if c%PromoteEvery == 0 { if c%PromoteEvery == 0 {
// The lock here is currently protecting i.remote access // The lock here is currently protecting i.remote access
i.RLock() i.RLock()
@ -658,7 +657,7 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) {
i.HandshakeComplete = true i.HandshakeComplete = true
//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen. //TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
// Clamping it to 2 gets us out of the woods for now // Clamping it to 2 gets us out of the woods for now
atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2) i.ConnectionState.messageCounter.Store(2)
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore)) i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore))

View File

@ -1,8 +1,6 @@
package nebula package nebula
import ( import (
"sync/atomic"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
@ -222,7 +220,7 @@ func (f *Interface) SendVia(viaIfc interface{},
) { ) {
via := viaIfc.(*HostInfo) via := viaIfc.(*HostInfo)
relay := relayIfc.(*Relay) relay := relayIfc.(*Relay)
c := atomic.AddUint64(&via.ConnectionState.atomicMessageCounter, 1) c := via.ConnectionState.messageCounter.Add(1)
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c) out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
f.connectionManager.Out(via.vpnIp) f.connectionManager.Out(via.vpnIp)
@ -281,7 +279,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
//TODO: enable if we do more than 1 tun queue //TODO: enable if we do more than 1 tun queue
//ci.writeLock.Lock() //ci.writeLock.Lock()
c := atomic.AddUint64(&ci.atomicMessageCounter, 1) c := ci.messageCounter.Add(1)
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c) out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)

View File

@ -67,7 +67,7 @@ type Interface struct {
routines int routines int
caPool *cert.NebulaCAPool caPool *cert.NebulaCAPool
disconnectInvalid bool disconnectInvalid bool
closed int32 closed atomic.Bool
relayManager *relayManager relayManager *relayManager
sendRecvErrorConfig sendRecvErrorConfig sendRecvErrorConfig sendRecvErrorConfig
@ -253,7 +253,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
for { for {
n, err := reader.Read(packet) n, err := reader.Read(packet)
if err != nil { if err != nil {
if errors.Is(err, os.ErrClosed) && atomic.LoadInt32(&f.closed) != 0 { if errors.Is(err, os.ErrClosed) && f.closed.Load() {
return return
} }
@ -391,7 +391,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
} }
func (f *Interface) Close() error { func (f *Interface) Close() error {
atomic.StoreInt32(&f.closed, 1) f.closed.Store(true)
// Release the tun device // Release the tun device
return f.inside.Close() return f.inside.Close()

View File

@ -9,7 +9,6 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -49,29 +48,29 @@ type LightHouse struct {
// respond with. // respond with.
// - When we are not a lighthouse, this filters which addresses we accept // - When we are not a lighthouse, this filters which addresses we accept
// from lighthouses. // from lighthouses.
atomicRemoteAllowList *RemoteAllowList remoteAllowList atomic.Pointer[RemoteAllowList]
// filters local addresses that we advertise to lighthouses // filters local addresses that we advertise to lighthouses
atomicLocalAllowList *LocalAllowList localAllowList atomic.Pointer[LocalAllowList]
// used to trigger the HandshakeManager when we receive HostQueryReply // used to trigger the HandshakeManager when we receive HostQueryReply
handshakeTrigger chan<- iputil.VpnIp handshakeTrigger chan<- iputil.VpnIp
// atomicStaticList exists to avoid having a bool in each addrMap entry // staticList exists to avoid having a bool in each addrMap entry
// since static should be rare // since static should be rare
atomicStaticList map[iputil.VpnIp]struct{} staticList atomic.Pointer[map[iputil.VpnIp]struct{}]
atomicLighthouses map[iputil.VpnIp]struct{} lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}]
atomicInterval int64 interval atomic.Int64
updateCancel context.CancelFunc updateCancel context.CancelFunc
updateParentCtx context.Context updateParentCtx context.Context
updateUdp udp.EncWriter updateUdp udp.EncWriter
nebulaPort uint32 // 32 bits because protobuf does not have a uint16 nebulaPort uint32 // 32 bits because protobuf does not have a uint16
atomicAdvertiseAddrs []netIpAndPort advertiseAddrs atomic.Pointer[[]netIpAndPort]
// IP's of relays that can be used by peers to access me // IP's of relays that can be used by peers to access me
atomicRelaysForMe []iputil.VpnIp relaysForMe atomic.Pointer[[]iputil.VpnIp]
metrics *MessageMetrics metrics *MessageMetrics
metricHolepunchTx metrics.Counter metricHolepunchTx metrics.Counter
@ -98,18 +97,20 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet,
ones, _ := myVpnNet.Mask.Size() ones, _ := myVpnNet.Mask.Size()
h := LightHouse{ h := LightHouse{
amLighthouse: amLighthouse, amLighthouse: amLighthouse,
myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP),
myVpnZeros: iputil.VpnIp(32 - ones), myVpnZeros: iputil.VpnIp(32 - ones),
myVpnNet: myVpnNet, myVpnNet: myVpnNet,
addrMap: make(map[iputil.VpnIp]*RemoteList), addrMap: make(map[iputil.VpnIp]*RemoteList),
nebulaPort: nebulaPort, nebulaPort: nebulaPort,
atomicLighthouses: make(map[iputil.VpnIp]struct{}), punchConn: pc,
atomicStaticList: make(map[iputil.VpnIp]struct{}), punchy: p,
punchConn: pc, l: l,
punchy: p,
l: l,
} }
lighthouses := make(map[iputil.VpnIp]struct{})
h.lighthouses.Store(&lighthouses)
staticList := make(map[iputil.VpnIp]struct{})
h.staticList.Store(&staticList)
if c.GetBool("stats.lighthouse_metrics", false) { if c.GetBool("stats.lighthouse_metrics", false) {
h.metrics = newLighthouseMetrics() h.metrics = newLighthouseMetrics()
@ -137,31 +138,31 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet,
} }
func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} { func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} {
return *(*map[iputil.VpnIp]struct{})(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicStaticList)))) return *lh.staticList.Load()
} }
func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} { func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} {
return *(*map[iputil.VpnIp]struct{})(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLighthouses)))) return *lh.lighthouses.Load()
} }
func (lh *LightHouse) GetRemoteAllowList() *RemoteAllowList { func (lh *LightHouse) GetRemoteAllowList() *RemoteAllowList {
return (*RemoteAllowList)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRemoteAllowList)))) return lh.remoteAllowList.Load()
} }
func (lh *LightHouse) GetLocalAllowList() *LocalAllowList { func (lh *LightHouse) GetLocalAllowList() *LocalAllowList {
return (*LocalAllowList)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLocalAllowList)))) return lh.localAllowList.Load()
} }
func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort { func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort {
return *(*[]netIpAndPort)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicAdvertiseAddrs)))) return *lh.advertiseAddrs.Load()
} }
func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
return *(*[]iputil.VpnIp)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRelaysForMe)))) return *lh.relaysForMe.Load()
} }
func (lh *LightHouse) GetUpdateInterval() int64 { func (lh *LightHouse) GetUpdateInterval() int64 {
return atomic.LoadInt64(&lh.atomicInterval) return lh.interval.Load()
} }
func (lh *LightHouse) reload(c *config.C, initial bool) error { func (lh *LightHouse) reload(c *config.C, initial bool) error {
@ -188,7 +189,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort}) advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort})
} }
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicAdvertiseAddrs)), unsafe.Pointer(&advAddrs)) lh.advertiseAddrs.Store(&advAddrs)
if !initial { if !initial {
lh.l.Info("lighthouse.advertise_addrs has changed") lh.l.Info("lighthouse.advertise_addrs has changed")
@ -196,10 +197,10 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
} }
if initial || c.HasChanged("lighthouse.interval") { if initial || c.HasChanged("lighthouse.interval") {
atomic.StoreInt64(&lh.atomicInterval, int64(c.GetInt("lighthouse.interval", 10))) lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10)))
if !initial { if !initial {
lh.l.Infof("lighthouse.interval changed to %v", lh.atomicInterval) lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load())
if lh.updateCancel != nil { if lh.updateCancel != nil {
// May not always have a running routine // May not always have a running routine
@ -216,7 +217,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) return util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
} }
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRemoteAllowList)), unsafe.Pointer(ral)) lh.remoteAllowList.Store(ral)
if !initial { if !initial {
//TODO: a diff will be annoyingly difficult //TODO: a diff will be annoyingly difficult
lh.l.Info("lighthouse.remote_allow_list and/or lighthouse.remote_allow_ranges has changed") lh.l.Info("lighthouse.remote_allow_list and/or lighthouse.remote_allow_ranges has changed")
@ -229,7 +230,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err) return util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
} }
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLocalAllowList)), unsafe.Pointer(lal)) lh.localAllowList.Store(lal)
if !initial { if !initial {
//TODO: a diff will be annoyingly difficult //TODO: a diff will be annoyingly difficult
lh.l.Info("lighthouse.local_allow_list has changed") lh.l.Info("lighthouse.local_allow_list has changed")
@ -244,7 +245,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return err return err
} }
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicStaticList)), unsafe.Pointer(&staticList)) lh.staticList.Store(&staticList)
if !initial { if !initial {
//TODO: we should remove any remote list entries for static hosts that were removed/modified? //TODO: we should remove any remote list entries for static hosts that were removed/modified?
lh.l.Info("static_host_map has changed") lh.l.Info("static_host_map has changed")
@ -259,7 +260,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return err return err
} }
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLighthouses)), unsafe.Pointer(&lhMap)) lh.lighthouses.Store(&lhMap)
if !initial { if !initial {
//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic //NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
lh.l.Info("lighthouse.hosts has changed") lh.l.Info("lighthouse.hosts has changed")
@ -274,7 +275,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.l.Info("Ignoring relays from config because am_relay is true") lh.l.Info("Ignoring relays from config because am_relay is true")
} }
relaysForMe := []iputil.VpnIp{} relaysForMe := []iputil.VpnIp{}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRelaysForMe)), unsafe.Pointer(&relaysForMe)) lh.relaysForMe.Store(&relaysForMe)
case false: case false:
relaysForMe := []iputil.VpnIp{} relaysForMe := []iputil.VpnIp{}
for _, v := range c.GetStringSlice("relay.relays", nil) { for _, v := range c.GetStringSlice("relay.relays", nil) {
@ -285,7 +286,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP)) relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP))
} }
} }
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRelaysForMe)), unsafe.Pointer(&relaysForMe)) lh.relaysForMe.Store(&relaysForMe)
} }
} }
@ -460,7 +461,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner // AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
//NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) { func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) {
lh.Lock() lh.Lock()
am := lh.unlockedGetRemoteList(vpnIp) am := lh.unlockedGetRemoteList(vpnIp)

View File

@ -9,10 +9,10 @@ import (
) )
type Punchy struct { type Punchy struct {
atomicPunch int32 punch atomic.Bool
atomicRespond int32 respond atomic.Bool
atomicDelay time.Duration delay atomic.Int64
l *logrus.Logger l *logrus.Logger
} }
func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy { func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy {
@ -36,12 +36,7 @@ func (p *Punchy) reload(c *config.C, initial bool) {
yes = c.GetBool("punchy", false) yes = c.GetBool("punchy", false)
} }
if yes { p.punch.Store(yes)
atomic.StoreInt32(&p.atomicPunch, 1)
} else {
atomic.StoreInt32(&p.atomicPunch, 0)
}
} else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") { } else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
//TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here //TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here
p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.") p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.")
@ -56,11 +51,7 @@ func (p *Punchy) reload(c *config.C, initial bool) {
yes = c.GetBool("punch_back", false) yes = c.GetBool("punch_back", false)
} }
if yes { p.respond.Store(yes)
atomic.StoreInt32(&p.atomicRespond, 1)
} else {
atomic.StoreInt32(&p.atomicRespond, 0)
}
if !initial { if !initial {
p.l.Infof("punchy.respond changed to %v", p.GetRespond()) p.l.Infof("punchy.respond changed to %v", p.GetRespond())
@ -69,7 +60,7 @@ func (p *Punchy) reload(c *config.C, initial bool) {
//NOTE: this will not apply to any in progress operations, only the next one //NOTE: this will not apply to any in progress operations, only the next one
if initial || c.HasChanged("punchy.delay") { if initial || c.HasChanged("punchy.delay") {
atomic.StoreInt64((*int64)(&p.atomicDelay), (int64)(c.GetDuration("punchy.delay", time.Second))) p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
if !initial { if !initial {
p.l.Infof("punchy.delay changed to %s", p.GetDelay()) p.l.Infof("punchy.delay changed to %s", p.GetDelay())
} }
@ -77,13 +68,13 @@ func (p *Punchy) reload(c *config.C, initial bool) {
} }
func (p *Punchy) GetPunch() bool { func (p *Punchy) GetPunch() bool {
return atomic.LoadInt32(&p.atomicPunch) == 1 return p.punch.Load()
} }
func (p *Punchy) GetRespond() bool { func (p *Punchy) GetRespond() bool {
return atomic.LoadInt32(&p.atomicRespond) == 1 return p.respond.Load()
} }
func (p *Punchy) GetDelay() time.Duration { func (p *Punchy) GetDelay() time.Duration {
return (time.Duration)(atomic.LoadInt64((*int64)(&p.atomicDelay))) return (time.Duration)(p.delay.Load())
} }

View File

@ -13,9 +13,9 @@ import (
) )
type relayManager struct { type relayManager struct {
l *logrus.Logger l *logrus.Logger
hostmap *HostMap hostmap *HostMap
atomicAmRelay int32 amRelay atomic.Bool
} }
func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager { func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager {
@ -41,18 +41,11 @@ func (rm *relayManager) reload(c *config.C, initial bool) error {
} }
func (rm *relayManager) GetAmRelay() bool { func (rm *relayManager) GetAmRelay() bool {
return atomic.LoadInt32(&rm.atomicAmRelay) == 1 return rm.amRelay.Load()
} }
func (rm *relayManager) setAmRelay(v bool) { func (rm *relayManager) setAmRelay(v bool) {
var val int32 rm.amRelay.Store(v)
switch v {
case true:
val = 1
case false:
val = 0
}
atomic.StoreInt32(&rm.atomicAmRelay, val)
} }
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // AddRelay finds an available relay index on the hostmap, and associates the relay info with it.

View File

@ -130,7 +130,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
// LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr // LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available // It will mark the deduplicated address list as dirty, so do not call it unless new information is available
//TODO: this needs to support the allow list list // TODO: this needs to support the allow list list
func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) { func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()

View File

@ -59,18 +59,14 @@ func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime //go:linkname nanotime runtime.nanotime
func nanotime() int64 func nanotime() int64
//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun // CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused. // interface with the same name exist, it is reused.
//
func CreateTUN(ifname string, mtu int) (Device, error) { func CreateTUN(ifname string, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
} }
//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused. // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
if err != nil { if err != nil {