mirror of https://github.com/slackhq/nebula.git
211 lines
4.5 KiB
Go
211 lines
4.5 KiB
Go
package nebula
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/gaissmai/bart"
|
|
"github.com/miekg/dns"
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/slackhq/nebula/config"
|
|
)
|
|
|
|
// This whole thing should be rewritten to use context
|
|
|
|
var dnsR *dnsRecords
|
|
var dnsServer *dns.Server
|
|
var dnsAddr string
|
|
|
|
type dnsRecords struct {
|
|
sync.RWMutex
|
|
l *logrus.Logger
|
|
dnsMap4 map[string]netip.Addr
|
|
dnsMap6 map[string]netip.Addr
|
|
hostMap *HostMap
|
|
myVpnAddrsTable *bart.Table[struct{}]
|
|
}
|
|
|
|
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
|
return &dnsRecords{
|
|
l: l,
|
|
dnsMap4: make(map[string]netip.Addr),
|
|
dnsMap6: make(map[string]netip.Addr),
|
|
hostMap: hostMap,
|
|
myVpnAddrsTable: cs.myVpnAddrsTable,
|
|
}
|
|
}
|
|
|
|
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
|
data = strings.ToLower(data)
|
|
d.RLock()
|
|
defer d.RUnlock()
|
|
switch q {
|
|
case dns.TypeA:
|
|
if r, ok := d.dnsMap4[data]; ok {
|
|
return r
|
|
}
|
|
case dns.TypeAAAA:
|
|
if r, ok := d.dnsMap6[data]; ok {
|
|
return r
|
|
}
|
|
}
|
|
|
|
return netip.Addr{}
|
|
}
|
|
|
|
func (d *dnsRecords) QueryCert(data string) string {
|
|
ip, err := netip.ParseAddr(data[:len(data)-1])
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
hostinfo := d.hostMap.QueryVpnAddr(ip)
|
|
if hostinfo == nil {
|
|
return ""
|
|
}
|
|
|
|
q := hostinfo.GetCert()
|
|
if q == nil {
|
|
return ""
|
|
}
|
|
|
|
b, err := q.Certificate.MarshalJSON()
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
return string(b)
|
|
}
|
|
|
|
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
|
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
|
host = strings.ToLower(host)
|
|
d.Lock()
|
|
defer d.Unlock()
|
|
haveV4 := false
|
|
haveV6 := false
|
|
for _, addr := range addresses {
|
|
if addr.Is4() && !haveV4 {
|
|
d.dnsMap4[host] = addr
|
|
haveV4 = true
|
|
} else if addr.Is6() && !haveV6 {
|
|
d.dnsMap6[host] = addr
|
|
haveV6 = true
|
|
}
|
|
if haveV4 && haveV6 {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
|
a, _, _ := net.SplitHostPort(addr)
|
|
b, err := netip.ParseAddr(a)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
if b.IsLoopback() {
|
|
return true
|
|
}
|
|
|
|
_, found := d.myVpnAddrsTable.Lookup(b)
|
|
return found //if we found it in this table, it's good
|
|
}
|
|
|
|
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|
for _, q := range m.Question {
|
|
switch q.Qtype {
|
|
case dns.TypeA, dns.TypeAAAA:
|
|
qType := dns.TypeToString[q.Qtype]
|
|
d.l.Debugf("Query for %s %s", qType, q.Name)
|
|
ip := d.Query(q.Qtype, q.Name)
|
|
if ip.IsValid() {
|
|
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
|
if err == nil {
|
|
m.Answer = append(m.Answer, rr)
|
|
}
|
|
}
|
|
case dns.TypeTXT:
|
|
// We only answer these queries from nebula nodes or localhost
|
|
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
|
return
|
|
}
|
|
d.l.Debugf("Query for TXT %s", q.Name)
|
|
ip := d.QueryCert(q.Name)
|
|
if ip != "" {
|
|
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
|
if err == nil {
|
|
m.Answer = append(m.Answer, rr)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(m.Answer) == 0 {
|
|
m.Rcode = dns.RcodeNameError
|
|
}
|
|
}
|
|
|
|
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|
m := new(dns.Msg)
|
|
m.SetReply(r)
|
|
m.Compress = false
|
|
|
|
switch r.Opcode {
|
|
case dns.OpcodeQuery:
|
|
d.parseQuery(m, w)
|
|
}
|
|
|
|
w.WriteMsg(m)
|
|
}
|
|
|
|
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
|
dnsR = newDnsRecords(l, cs, hostMap)
|
|
|
|
// attach request handler func
|
|
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
|
|
|
c.RegisterReloadCallback(func(c *config.C) {
|
|
reloadDns(l, c)
|
|
})
|
|
|
|
return func() {
|
|
startDns(l, c)
|
|
}
|
|
}
|
|
|
|
func getDnsServerAddr(c *config.C) string {
|
|
dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
|
|
// Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
|
|
if dnsHost == "[::]" {
|
|
dnsHost = "::"
|
|
}
|
|
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
|
}
|
|
|
|
func startDns(l *logrus.Logger, c *config.C) {
|
|
dnsAddr = getDnsServerAddr(c)
|
|
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
|
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
|
|
err := dnsServer.ListenAndServe()
|
|
defer dnsServer.Shutdown()
|
|
if err != nil {
|
|
l.Errorf("Failed to start server: %s\n ", err.Error())
|
|
}
|
|
}
|
|
|
|
func reloadDns(l *logrus.Logger, c *config.C) {
|
|
if dnsAddr == getDnsServerAddr(c) {
|
|
l.Debug("No DNS server config change detected")
|
|
return
|
|
}
|
|
|
|
l.Debug("Restarting DNS server")
|
|
dnsServer.Shutdown()
|
|
go startDns(l, c)
|
|
}
|