From 93bc1b72e3cd566e6447ad7a1f832474aad5dfcc Mon Sep 17 00:00:00 2001 From: Mohammed Al Sahaf Date: Tue, 12 Nov 2019 01:33:38 +0300 Subject: core: Use port ranges to avoid OOM with bad inputs (#2859) * fix OOM issue caught by fuzzing * use ParsedAddress as the struct name for the result of ParseNetworkAddress * simplify code using the ParsedAddress type * minor cleanups --- listeners.go | 101 ++++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 75 insertions(+), 26 deletions(-) (limited to 'listeners.go') diff --git a/listeners.go b/listeners.go index 4464b78..37b4c29 100644 --- a/listeners.go +++ b/listeners.go @@ -257,52 +257,94 @@ type globalListener struct { pc net.PacketConn } -var ( - listeners = make(map[string]*globalListener) - listenersMu sync.Mutex -) +// ParsedAddress contains the individual components +// for a parsed network address of the form accepted +// by ParseNetworkAddress(). Network should be a +// network value accepted by Go's net package. Port +// ranges are given by [StartPort, EndPort]. +type ParsedAddress struct { + Network string + Host string + StartPort uint + EndPort uint +} + +// JoinHostPort is like net.JoinHostPort, but where the port +// is StartPort + offset. +func (l ParsedAddress) JoinHostPort(offset uint) string { + return net.JoinHostPort(l.Host, strconv.Itoa(int(l.StartPort+offset))) +} -// ParseNetworkAddress parses addr, a string of the form "network/host:port" -// (with any part optional) into its component parts. Because a port can -// also be a port range, there may be multiple addresses returned. -func ParseNetworkAddress(addr string) (network string, addrs []string, err error) { +// PortRangeSize returns how many ports are in +// pa's port range. Port ranges are inclusive, +// so the size is the difference of start and +// end ports plus one. +func (pa ParsedAddress) PortRangeSize() uint { + return (pa.EndPort - pa.StartPort) + 1 +} + +// String reconstructs the address string to the form expected +// by ParseNetworkAddress(). +func (pa ParsedAddress) String() string { + port := strconv.FormatUint(uint64(pa.StartPort), 10) + if pa.StartPort != pa.EndPort { + port += "-" + strconv.FormatUint(uint64(pa.EndPort), 10) + } + return JoinNetworkAddress(pa.Network, pa.Host, port) +} + +// ParseNetworkAddress parses addr into its individual +// components. The input string is expected to be of +// the form "network/host:port-range" where any part is +// optional. The default network, if unspecified, is tcp. +// Port ranges are inclusive. +// +// Network addresses are distinct from URLs and do not +// use URL syntax. +func ParseNetworkAddress(addr string) (ParsedAddress, error) { var host, port string - network, host, port, err = SplitNetworkAddress(addr) + network, host, port, err := SplitNetworkAddress(addr) if network == "" { network = "tcp" } if err != nil { - return + return ParsedAddress{}, err } if network == "unix" || network == "unixgram" || network == "unixpacket" { - addrs = []string{host} - return + return ParsedAddress{ + Network: network, + Host: host, + }, nil } ports := strings.SplitN(port, "-", 2) if len(ports) == 1 { ports = append(ports, ports[0]) } - var start, end int - start, err = strconv.Atoi(ports[0]) + var start, end uint64 + start, err = strconv.ParseUint(ports[0], 10, 16) if err != nil { - return + return ParsedAddress{}, fmt.Errorf("invalid start port: %v", err) } - end, err = strconv.Atoi(ports[1]) + end, err = strconv.ParseUint(ports[1], 10, 16) if err != nil { - return + return ParsedAddress{}, fmt.Errorf("invalid end port: %v", err) } if end < start { - err = fmt.Errorf("end port must be greater than start port") - return + return ParsedAddress{}, fmt.Errorf("end port must not be less than start port") } - for p := start; p <= end; p++ { - addrs = append(addrs, net.JoinHostPort(host, fmt.Sprintf("%d", p))) + if (end - start) > maxPortSpan { + return ParsedAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan) } - return + return ParsedAddress{ + Network: network, + Host: host, + StartPort: uint(start), + EndPort: uint(end), + }, nil } // SplitNetworkAddress splits a into its network, host, and port components. -// Note that port may be a port range, or omitted for unix sockets. +// Note that port may be a port range (:X-Y), or omitted for unix sockets. func SplitNetworkAddress(a string) (network, host, port string, err error) { if idx := strings.Index(a, "/"); idx >= 0 { network = strings.ToLower(strings.TrimSpace(a[:idx])) @@ -317,9 +359,9 @@ func SplitNetworkAddress(a string) (network, host, port string, err error) { } // JoinNetworkAddress combines network, host, and port into a single -// address string of the form "network/host:port". Port may be a -// port range. For unix sockets, the network should be "unix" and -// the path to the socket should be given in the host argument. +// address string of the form accepted by ParseNetworkAddress(). For unix sockets, the network +// should be "unix" and the path to the socket should be given as the +// host parameter. func JoinNetworkAddress(network, host, port string) string { var a string if network != "" { @@ -332,3 +374,10 @@ func JoinNetworkAddress(network, host, port string) string { } return a } + +var ( + listeners = make(map[string]*globalListener) + listenersMu sync.Mutex +) + +const maxPortSpan = 65535 -- cgit v1.2.3