diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ca050bb..1151c89 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -22,6 +22,7 @@ import ( type StdConn struct { sysFd int + isV4 bool l *logrus.Logger batch int } @@ -45,9 +46,22 @@ const ( type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 +func maybeIPV4(ip net.IP) (net.IP, bool) { + ip4 := ip.To4() + if ip4 != nil { + return ip4, true + } + return ip, false +} + func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { + ipV4, isV4 := maybeIPV4(ip) + af := unix.AF_INET6 + if isV4 { + af = unix.AF_INET + } syscall.ForkLock.RLock() - fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) + fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP) if err == nil { unix.CloseOnExec(fd) } @@ -58,9 +72,6 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( return nil, fmt.Errorf("unable to open socket: %s", err) } - var lip [16]byte - copy(lip[:], ip.To16()) - if multi { if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) @@ -68,7 +79,17 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( } //TODO: support multiple listening IPs (for limiting ipv6) - if err = unix.Bind(fd, &unix.SockaddrInet6{Addr: lip, Port: port}); err != nil { + var sa unix.Sockaddr + if isV4 { + sa4 := &unix.SockaddrInet4{Port: port} + copy(sa4.Addr[:], ipV4) + sa = sa4 + } else { + sa6 := &unix.SockaddrInet6{Port: port} + copy(sa6.Addr[:], ip.To16()) + sa = sa6 + } + if err = unix.Bind(fd, sa); err != nil { return nil, fmt.Errorf("unable to bind to socket: %s", err) } @@ -77,7 +98,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &StdConn{sysFd: fd, l: l, batch: batch}, err + return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err } func (u *StdConn) Rebind() error { @@ -143,7 +164,11 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew //metric.Update(int64(n)) for i := 0; i < n; i++ { - udpAddr.IP = names[i][8:24] + if u.isV4 { + udpAddr.IP = names[i][4:8] + } else { + udpAddr.IP = names[i][8:24] + } udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) } @@ -192,13 +217,18 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { } func (u *StdConn) WriteTo(b []byte, addr *Addr) error { + if u.isV4 { + return u.writeTo4(b, addr) + } + return u.writeTo6(b, addr) +} +func (u *StdConn) writeTo6(b []byte, addr *Addr) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 - p := (*[2]byte)(unsafe.Pointer(&rsa.Port)) - p[0] = byte(addr.Port >> 8) - p[1] = byte(addr.Port) - copy(rsa.Addr[:], addr.IP) + // Little Endian -> Network Endian + rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) + copy(rsa.Addr[:], addr.IP.To16()) for { _, _, err := unix.Syscall6( @@ -221,6 +251,39 @@ func (u *StdConn) WriteTo(b []byte, addr *Addr) error { } } +func (u *StdConn) writeTo4(b []byte, addr *Addr) error { + addrV4, isAddrV4 := maybeIPV4(addr.IP) + if !isAddrV4 { + return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") + } + + var rsa unix.RawSockaddrInet4 + rsa.Family = unix.AF_INET + // Little Endian -> Network Endian + rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) + copy(rsa.Addr[:], addrV4) + + for { + _, _, err := unix.Syscall6( + unix.SYS_SENDTO, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&b[0])), + uintptr(len(b)), + uintptr(0), + uintptr(unsafe.Pointer(&rsa)), + uintptr(unix.SizeofSockaddrInet4), + ) + + if err != 0 { + return &net.OpError{Op: "sendto", Err: err} + } + + //TODO: handle incomplete writes + + return nil + } +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 {