mirror of https://github.com/slackhq/nebula.git
Fix most known data races (#396)
This change fixes all of the known data races that `make smoke-docker-race` finds, except for one. Most of these races are around the handshake phase for a hostinfo, so we add a RWLock to the hostinfo and Lock during each of the handshake stages. Some of the other races are around consistently using `atomic` around the `messageCounter` field. To make this harder to mess up, I have renamed the field to `atomicMessageCounter` (I also removed the unnecessary extra pointer deference as we can just point directly to the struct field). The last remaining data race is around reading `ConnectionInfo.ready`, which is a boolean that is only written to once when the handshake has finished. Due to it being in the hot path for packets and the rare case that this could actually be an issue, holding off on fixing that one for now. here is the results of `make smoke-docker-race`: before: lighthouse1: Found 2 data race(s) host2: Found 36 data race(s) host3: Found 17 data race(s) host4: Found 31 data race(s) after: host2: Found 1 data race(s) host4: Found 1 data race(s) Fixes: #147 Fixes: #226 Fixes: #283 Fixes: #316
This commit is contained in:
parent
29c5f31f90
commit
d604270966
|
@ -49,9 +49,8 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||
// Add an ip we have established a connection w/ to hostmap
|
||||
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
certState: cs,
|
||||
H: &noise.HandshakeState{},
|
||||
messageCounter: new(uint64),
|
||||
certState: cs,
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
|
||||
// We saw traffic out to vpnIP
|
||||
|
@ -115,9 +114,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||
// Add an ip we have established a connection w/ to hostmap
|
||||
hostinfo := nc.hostMap.AddVpnIP(vpnIP)
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
certState: cs,
|
||||
H: &noise.HandshakeState{},
|
||||
messageCounter: new(uint64),
|
||||
certState: cs,
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
|
||||
// We saw traffic out to vpnIP
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
|
@ -12,17 +13,17 @@ import (
|
|||
const ReplayWindow = 1024
|
||||
|
||||
type ConnectionState struct {
|
||||
eKey *NebulaCipherState
|
||||
dKey *NebulaCipherState
|
||||
H *noise.HandshakeState
|
||||
certState *CertState
|
||||
peerCert *cert.NebulaCertificate
|
||||
initiator bool
|
||||
messageCounter *uint64
|
||||
window *Bits
|
||||
queueLock sync.Mutex
|
||||
writeLock sync.Mutex
|
||||
ready bool
|
||||
eKey *NebulaCipherState
|
||||
dKey *NebulaCipherState
|
||||
H *noise.HandshakeState
|
||||
certState *CertState
|
||||
peerCert *cert.NebulaCertificate
|
||||
initiator bool
|
||||
atomicMessageCounter uint64
|
||||
window *Bits
|
||||
queueLock sync.Mutex
|
||||
writeLock sync.Mutex
|
||||
ready bool
|
||||
}
|
||||
|
||||
func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||
|
@ -54,12 +55,11 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa
|
|||
// The queue and ready params prevent a counter race that would happen when
|
||||
// sending stored packets and simultaneously accepting new traffic.
|
||||
ci := &ConnectionState{
|
||||
H: hs,
|
||||
initiator: initiator,
|
||||
window: b,
|
||||
ready: false,
|
||||
certState: curCertState,
|
||||
messageCounter: new(uint64),
|
||||
H: hs,
|
||||
initiator: initiator,
|
||||
window: b,
|
||||
ready: false,
|
||||
certState: curCertState,
|
||||
}
|
||||
|
||||
return ci
|
||||
|
@ -69,7 +69,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
|||
return json.Marshal(m{
|
||||
"certificate": cs.peerCert,
|
||||
"initiator": cs.initiator,
|
||||
"message_counter": cs.messageCounter,
|
||||
"message_counter": atomic.LoadUint64(&cs.atomicMessageCounter),
|
||||
"ready": cs.ready,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -156,7 +157,7 @@ func copyHostInfo(h *HostInfo) ControlHostInfo {
|
|||
RemoteIndex: h.remoteIndexId,
|
||||
RemoteAddrs: make([]udpAddr, len(addrs), len(addrs)),
|
||||
CachedPackets: len(h.packetStore),
|
||||
MessageCounter: *h.ConnectionState.messageCounter,
|
||||
MessageCounter: atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter),
|
||||
}
|
||||
|
||||
if c := h.GetCert(); c != nil {
|
||||
|
|
|
@ -43,15 +43,13 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
|||
},
|
||||
Signature: []byte{1, 2, 1, 2, 1, 3},
|
||||
}
|
||||
counter := uint64(0)
|
||||
|
||||
remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
|
||||
hm.Add(ip2int(ipNet.IP), &HostInfo{
|
||||
remote: remote1,
|
||||
Remotes: remotes,
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: crt,
|
||||
messageCounter: &counter,
|
||||
peerCert: crt,
|
||||
},
|
||||
remoteIndexId: 200,
|
||||
localIndexId: 201,
|
||||
|
@ -62,8 +60,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
|
|||
remote: remote1,
|
||||
Remotes: remotes,
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: nil,
|
||||
messageCounter: &counter,
|
||||
peerCert: nil,
|
||||
},
|
||||
remoteIndexId: 200,
|
||||
localIndexId: 201,
|
||||
|
|
|
@ -54,7 +54,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
|
|||
}
|
||||
|
||||
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1)
|
||||
atomic.AddUint64(ci.messageCounter, 1)
|
||||
atomic.AddUint64(&ci.atomicMessageCounter, 1)
|
||||
|
||||
msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
|
||||
if err != nil {
|
||||
|
@ -99,26 +99,31 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||
}
|
||||
|
||||
hostinfo, _ := f.handshakeManager.pendingHostMap.QueryReverseIndex(hs.Details.InitiatorIndex)
|
||||
if hostinfo != nil && bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
|
||||
if msg, ok := hostinfo.HandshakePacket[2]; ok {
|
||||
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
||||
err := f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
WithError(err).Error("Failed to send handshake message")
|
||||
} else {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
Info("Handshake message sent")
|
||||
}
|
||||
return false
|
||||
}
|
||||
if hostinfo != nil {
|
||||
hostinfo.RLock()
|
||||
defer hostinfo.RUnlock()
|
||||
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true).
|
||||
WithField("packets", hostinfo.HandshakePacket).
|
||||
Error("Seen this handshake packet already but don't have a cached packet to return")
|
||||
if bytes.Equal(hostinfo.HandshakePacket[0], packet[HeaderLen:]) {
|
||||
if msg, ok := hostinfo.HandshakePacket[2]; ok {
|
||||
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
|
||||
err := f.outside.WriteTo(msg, addr)
|
||||
if err != nil {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
WithError(err).Error("Failed to send handshake message")
|
||||
} else {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
Info("Handshake message sent")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cached", true).
|
||||
WithField("packets", hostinfo.HandshakePacket).
|
||||
Error("Seen this handshake packet already but don't have a cached packet to return")
|
||||
}
|
||||
}
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
|
||||
|
@ -150,6 +155,9 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||
|
||||
return true
|
||||
}
|
||||
hostinfo.Lock()
|
||||
defer hostinfo.Unlock()
|
||||
|
||||
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
|
@ -272,6 +280,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
|
|||
if hostinfo == nil {
|
||||
return true
|
||||
}
|
||||
hostinfo.Lock()
|
||||
defer hostinfo.Unlock()
|
||||
|
||||
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
|
||||
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
|
||||
|
|
|
@ -103,6 +103,8 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
hostinfo.Lock()
|
||||
defer hostinfo.Unlock()
|
||||
|
||||
// If we haven't finished the handshake and we haven't hit max retries, query
|
||||
// lighthouse and then send the handshake packet again.
|
||||
|
|
21
hostmap.go
21
hostmap.go
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
|
@ -35,6 +36,8 @@ type HostMap struct {
|
|||
}
|
||||
|
||||
type HostInfo struct {
|
||||
sync.RWMutex
|
||||
|
||||
remote *udpAddr
|
||||
Remotes []*HostInfoDest
|
||||
promoteCounter uint32
|
||||
|
@ -231,6 +234,9 @@ func (hm *HostMap) DeleteIndex(index uint32) {
|
|||
hm.Lock()
|
||||
hostinfo, ok := hm.Indexes[index]
|
||||
if ok {
|
||||
hostinfo.Lock()
|
||||
defer hostinfo.Unlock()
|
||||
|
||||
delete(hm.Indexes, index)
|
||||
delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
|
||||
|
||||
|
@ -513,8 +519,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
|
|||
return
|
||||
}
|
||||
|
||||
i.promoteCounter++
|
||||
if i.promoteCounter%PromoteEvery == 0 {
|
||||
if atomic.AddUint32(&i.promoteCounter, 1)&PromoteEvery == 0 {
|
||||
// return early if we are already on a preferred remote
|
||||
rIP := udp2ip(i.remote)
|
||||
for _, l := range preferredRanges {
|
||||
|
@ -615,10 +620,12 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
|
|||
copy(tempPacket, packet)
|
||||
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
|
||||
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
|
||||
i.logger().
|
||||
WithField("length", len(i.packetStore)).
|
||||
WithField("stored", true).
|
||||
Debugf("Packet store")
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
i.logger().
|
||||
WithField("length", len(i.packetStore)).
|
||||
WithField("stored", true).
|
||||
Debugf("Packet store")
|
||||
}
|
||||
|
||||
} else if l.Level >= logrus.DebugLevel {
|
||||
i.logger().
|
||||
|
@ -638,7 +645,7 @@ func (i *HostInfo) handshakeComplete() {
|
|||
i.HandshakeComplete = true
|
||||
//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
|
||||
// Clamping it to 2 gets us out of the woods for now
|
||||
*i.ConnectionState.messageCounter = 2
|
||||
atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2)
|
||||
i.logger().Debugf("Sending %d stored packets", len(i.packetStore))
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
|
|
31
inside.go
31
inside.go
|
@ -103,16 +103,21 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
|
|||
|
||||
// If we have already created the handshake packet, we don't want to call the function at all.
|
||||
if !hostinfo.HandshakeReady {
|
||||
ixHandshakeStage0(f, vpnIp, hostinfo)
|
||||
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
||||
//xx_handshakeStage0(f, ip, hostinfo)
|
||||
hostinfo.Lock()
|
||||
defer hostinfo.Unlock()
|
||||
|
||||
// If this is a static host, we don't need to wait for the HostQueryReply
|
||||
// We can trigger the handshake right now
|
||||
if _, ok := f.lightHouse.staticList[vpnIp]; ok {
|
||||
select {
|
||||
case f.handshakeManager.trigger <- vpnIp:
|
||||
default:
|
||||
if !hostinfo.HandshakeReady {
|
||||
ixHandshakeStage0(f, vpnIp, hostinfo)
|
||||
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
||||
//xx_handshakeStage0(f, ip, hostinfo)
|
||||
|
||||
// If this is a static host, we don't need to wait for the HostQueryReply
|
||||
// We can trigger the handshake right now
|
||||
if _, ok := f.lightHouse.staticList[vpnIp]; ok {
|
||||
select {
|
||||
case f.handshakeManager.trigger <- vpnIp:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -139,8 +144,8 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
|
|||
return
|
||||
}
|
||||
|
||||
f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
|
||||
if f.lightHouse != nil && *hostInfo.ConnectionState.messageCounter%5000 == 0 {
|
||||
messageCounter := f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
|
||||
if f.lightHouse != nil && messageCounter%5000 == 0 {
|
||||
f.lightHouse.Query(fp.RemoteIP, f)
|
||||
}
|
||||
}
|
||||
|
@ -223,7 +228,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
|
|||
var err error
|
||||
//TODO: enable if we do more than 1 tun queue
|
||||
//ci.writeLock.Lock()
|
||||
c := atomic.AddUint64(ci.messageCounter, 1)
|
||||
c := atomic.AddUint64(&ci.atomicMessageCounter, 1)
|
||||
|
||||
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
|
||||
out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c)
|
||||
|
@ -247,7 +252,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
|
|||
if err != nil {
|
||||
hostinfo.logger().WithError(err).
|
||||
WithField("udpAddr", remote).WithField("counter", c).
|
||||
WithField("attemptedCounter", ci.messageCounter).
|
||||
WithField("attemptedCounter", c).
|
||||
Error("Failed to encrypt outgoing packet")
|
||||
return c
|
||||
}
|
||||
|
|
10
interface.go
10
interface.go
|
@ -134,11 +134,6 @@ func (f *Interface) run() {
|
|||
|
||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||
|
||||
// Launch n queues to read packets from udp
|
||||
for i := 0; i < f.routines; i++ {
|
||||
go f.listenOut(i)
|
||||
}
|
||||
|
||||
// Prepare n tun queues
|
||||
var reader io.ReadWriteCloser = f.inside
|
||||
for i := 0; i < f.routines; i++ {
|
||||
|
@ -155,6 +150,11 @@ func (f *Interface) run() {
|
|||
l.Fatal(err)
|
||||
}
|
||||
|
||||
// Launch n queues to read packets from udp
|
||||
for i := 0; i < f.routines; i++ {
|
||||
go f.listenOut(i)
|
||||
}
|
||||
|
||||
// Launch n queues to read packets from tun dev
|
||||
for i := 0; i < f.routines; i++ {
|
||||
go f.listenIn(f.readers[i], i)
|
||||
|
|
3
ssh.go
3
ssh.go
|
@ -11,6 +11,7 @@ import (
|
|||
"reflect"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -353,7 +354,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
|
|||
}
|
||||
|
||||
if v.ConnectionState != nil {
|
||||
h["messageCounter"] = v.ConnectionState.messageCounter
|
||||
h["messageCounter"] = atomic.LoadUint64(&v.ConnectionState.atomicMessageCounter)
|
||||
}
|
||||
|
||||
d[x] = h
|
||||
|
|
Loading…
Reference in New Issue