mirror of https://github.com/slackhq/nebula.git
Support UDP dialling with gvisor (#1181)
This commit is contained in:
parent
0736cfa562
commit
3dc56e1184
|
@ -4,6 +4,7 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/service"
|
"github.com/slackhq/nebula/service"
|
||||||
|
@ -54,16 +55,16 @@ pki:
|
||||||
cert: /home/rice/Developer/nebula-config/app.crt
|
cert: /home/rice/Developer/nebula-config/app.crt
|
||||||
key: /home/rice/Developer/nebula-config/app.key
|
key: /home/rice/Developer/nebula-config/app.key
|
||||||
`
|
`
|
||||||
var config config.C
|
var cfg config.C
|
||||||
if err := config.LoadString(configStr); err != nil {
|
if err := cfg.LoadString(configStr); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
service, err := service.New(&config)
|
svc, err := service.New(&cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ln, err := service.Listen("tcp", ":1234")
|
ln, err := svc.Listen("tcp", ":1234")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -73,16 +74,24 @@ pki:
|
||||||
log.Printf("accept error: %s", err)
|
log.Printf("accept error: %s", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func(conn net.Conn) {
|
||||||
|
_ = conn.Close()
|
||||||
|
}(conn)
|
||||||
|
|
||||||
log.Printf("got connection")
|
log.Printf("got connection")
|
||||||
|
|
||||||
conn.Write([]byte("hello world\n"))
|
_, err = conn.Write([]byte("hello world\n"))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("write error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
scanner := bufio.NewScanner(conn)
|
scanner := bufio.NewScanner(conn)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
message := scanner.Text()
|
message := scanner.Text()
|
||||||
fmt.Fprintf(conn, "echo: %q\n", message)
|
_, err = fmt.Fprintf(conn, "echo: %q\n", message)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("write error: %s", err)
|
||||||
|
}
|
||||||
log.Printf("got message %q", message)
|
log.Printf("got message %q", message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,8 +101,8 @@ pki:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
service.Close()
|
_ = svc.Close()
|
||||||
if err := service.Wait(); err != nil {
|
if err := svc.Wait(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -153,24 +154,48 @@ func New(config *config.C) (*Service, error) {
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContext dials the provided address. Currently only TCP is supported.
|
func getProtocolNumber(addr netip.Addr) tcpip.NetworkProtocolNumber {
|
||||||
func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
if addr.Is6() {
|
||||||
if network != "tcp" && network != "tcp4" {
|
return ipv6.ProtocolNumber
|
||||||
return nil, errors.New("only tcp is supported")
|
}
|
||||||
|
return ipv4.ProtocolNumber
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := net.ResolveTCPAddr(network, address)
|
// DialContext dials the provided address.
|
||||||
|
func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
switch network {
|
||||||
|
case "udp", "udp4", "udp6":
|
||||||
|
addr, err := net.ResolveUDPAddr(network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
fullAddr := tcpip.FullAddress{
|
fullAddr := tcpip.FullAddress{
|
||||||
NIC: nicID,
|
NIC: nicID,
|
||||||
Addr: tcpip.AddrFromSlice(addr.IP),
|
Addr: tcpip.AddrFromSlice(addr.IP),
|
||||||
Port: uint16(addr.Port),
|
Port: uint16(addr.Port),
|
||||||
}
|
}
|
||||||
|
num := getProtocolNumber(addr.AddrPort().Addr())
|
||||||
|
return gonet.DialUDP(s.ipstack, nil, &fullAddr, num)
|
||||||
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
addr, err := net.ResolveTCPAddr(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fullAddr := tcpip.FullAddress{
|
||||||
|
NIC: nicID,
|
||||||
|
Addr: tcpip.AddrFromSlice(addr.IP),
|
||||||
|
Port: uint16(addr.Port),
|
||||||
|
}
|
||||||
|
num := getProtocolNumber(addr.AddrPort().Addr())
|
||||||
|
return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, num)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown network type: %s", network)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber)
|
// Dial dials the provided address
|
||||||
|
func (s *Service) Dial(network, address string) (net.Conn, error) {
|
||||||
|
return s.DialContext(context.Background(), network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen listens on the provided address. Currently only TCP with wildcard
|
// Listen listens on the provided address. Currently only TCP with wildcard
|
||||||
|
|
Loading…
Reference in New Issue