nebula/connection_state.go

88 lines
2.3 KiB
Go
Raw Normal View History

2019-11-19 10:00:20 -07:00
package nebula
import (
"crypto/rand"
"encoding/json"
"sync"
"sync/atomic"
2019-11-19 10:00:20 -07:00
"github.com/flynn/noise"
2021-03-26 08:46:30 -06:00
"github.com/sirupsen/logrus"
2019-11-19 10:00:20 -07:00
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/noiseutil"
2019-11-19 10:00:20 -07:00
)
const ReplayWindow = 1024
type ConnectionState struct {
eKey *NebulaCipherState
dKey *NebulaCipherState
H *noise.HandshakeState
myCert *cert.NebulaCertificate
peerCert *cert.NebulaCertificate
initiator bool
messageCounter atomic.Uint64
window *Bits
writeLock sync.Mutex
2019-11-19 10:00:20 -07:00
}
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
var dhFunc noise.DHFunc
switch certState.Certificate.Details.Curve {
case cert.Curve_CURVE25519:
dhFunc = noise.DH25519
case cert.Curve_P256:
dhFunc = noiseutil.DHP256
default:
l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
return nil
}
var cs noise.CipherSuite
if cipher == "chachapoly" {
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
} else {
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
2019-11-19 10:00:20 -07:00
}
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
2019-11-19 10:00:20 -07:00
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
2021-03-26 08:46:30 -06:00
b.Update(l, 0)
2019-11-19 10:00:20 -07:00
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: cs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
PresharedKey: psk,
PresharedKeyPlacement: pskStage,
})
if err != nil {
return nil
}
// The queue and ready params prevent a counter race that would happen when
// sending stored packets and simultaneously accepting new traffic.
ci := &ConnectionState{
H: hs,
initiator: initiator,
window: b,
myCert: certState.Certificate,
2019-11-19 10:00:20 -07:00
}
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
ci.messageCounter.Add(2)
2019-11-19 10:00:20 -07:00
return ci
}
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"certificate": cs.peerCert,
"initiator": cs.initiator,
"message_counter": cs.messageCounter.Load(),
2019-11-19 10:00:20 -07:00
})
}