summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMohammed Al Sahaf <msaa1990@gmail.com>2019-11-12 01:33:38 +0300
committerMatt Holt <mholt@users.noreply.github.com>2019-11-11 15:33:38 -0700
commit93bc1b72e3cd566e6447ad7a1f832474aad5dfcc (patch)
tree05ddeb324261d7058925948baa0077752fd5e453
parenta19da07b72d84432341990bcedce511fe2f980da (diff)
core: Use port ranges to avoid OOM with bad inputs (#2859)
* fix OOM issue caught by fuzzing * use ParsedAddress as the struct name for the result of ParseNetworkAddress * simplify code using the ParsedAddress type * minor cleanups
-rw-r--r--admin.go16
-rw-r--r--listeners.go101
-rw-r--r--listeners_fuzz.go2
-rw-r--r--listeners_test.go105
-rw-r--r--modules/caddyhttp/caddyhttp.go33
-rw-r--r--modules/caddyhttp/reverseproxy/healthchecks.go10
-rw-r--r--modules/caddyhttp/reverseproxy/hosts.go24
-rw-r--r--modules/caddyhttp/server.go40
8 files changed, 201 insertions, 130 deletions
diff --git a/admin.go b/admin.go
index 502a968..b1ced18 100644
--- a/admin.go
+++ b/admin.go
@@ -48,23 +48,19 @@ type AdminConfig struct {
// listenAddr extracts a singular listen address from ac.Listen,
// returning the network and the address of the listener.
-func (admin AdminConfig) listenAddr() (netw string, addr string, err error) {
- var listenAddrs []string
+func (admin AdminConfig) listenAddr() (string, string, error) {
input := admin.Listen
if input == "" {
input = DefaultAdminListen
}
- netw, listenAddrs, err = ParseNetworkAddress(input)
+ listenAddr, err := ParseNetworkAddress(input)
if err != nil {
- err = fmt.Errorf("parsing admin listener address: %v", err)
- return
+ return "", "", fmt.Errorf("parsing admin listener address: %v", err)
}
- if len(listenAddrs) != 1 {
- err = fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddrs)
- return
+ if listenAddr.PortRangeSize() != 1 {
+ return "", "", fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddr)
}
- addr = listenAddrs[0]
- return
+ return listenAddr.Network, listenAddr.JoinHostPort(0), nil
}
// newAdminHandler reads admin's config and returns an http.Handler suitable
diff --git a/listeners.go b/listeners.go
index 4464b78..37b4c29 100644
--- a/listeners.go
+++ b/listeners.go
@@ -257,52 +257,94 @@ type globalListener struct {
pc net.PacketConn
}
-var (
- listeners = make(map[string]*globalListener)
- listenersMu sync.Mutex
-)
+// ParsedAddress contains the individual components
+// for a parsed network address of the form accepted
+// by ParseNetworkAddress(). Network should be a
+// network value accepted by Go's net package. Port
+// ranges are given by [StartPort, EndPort].
+type ParsedAddress struct {
+ Network string
+ Host string
+ StartPort uint
+ EndPort uint
+}
+
+// JoinHostPort is like net.JoinHostPort, but where the port
+// is StartPort + offset.
+func (l ParsedAddress) JoinHostPort(offset uint) string {
+ return net.JoinHostPort(l.Host, strconv.Itoa(int(l.StartPort+offset)))
+}
-// ParseNetworkAddress parses addr, a string of the form "network/host:port"
-// (with any part optional) into its component parts. Because a port can
-// also be a port range, there may be multiple addresses returned.
-func ParseNetworkAddress(addr string) (network string, addrs []string, err error) {
+// PortRangeSize returns how many ports are in
+// pa's port range. Port ranges are inclusive,
+// so the size is the difference of start and
+// end ports plus one.
+func (pa ParsedAddress) PortRangeSize() uint {
+ return (pa.EndPort - pa.StartPort) + 1
+}
+
+// String reconstructs the address string to the form expected
+// by ParseNetworkAddress().
+func (pa ParsedAddress) String() string {
+ port := strconv.FormatUint(uint64(pa.StartPort), 10)
+ if pa.StartPort != pa.EndPort {
+ port += "-" + strconv.FormatUint(uint64(pa.EndPort), 10)
+ }
+ return JoinNetworkAddress(pa.Network, pa.Host, port)
+}
+
+// ParseNetworkAddress parses addr into its individual
+// components. The input string is expected to be of
+// the form "network/host:port-range" where any part is
+// optional. The default network, if unspecified, is tcp.
+// Port ranges are inclusive.
+//
+// Network addresses are distinct from URLs and do not
+// use URL syntax.
+func ParseNetworkAddress(addr string) (ParsedAddress, error) {
var host, port string
- network, host, port, err = SplitNetworkAddress(addr)
+ network, host, port, err := SplitNetworkAddress(addr)
if network == "" {
network = "tcp"
}
if err != nil {
- return
+ return ParsedAddress{}, err
}
if network == "unix" || network == "unixgram" || network == "unixpacket" {
- addrs = []string{host}
- return
+ return ParsedAddress{
+ Network: network,
+ Host: host,
+ }, nil
}
ports := strings.SplitN(port, "-", 2)
if len(ports) == 1 {
ports = append(ports, ports[0])
}
- var start, end int
- start, err = strconv.Atoi(ports[0])
+ var start, end uint64
+ start, err = strconv.ParseUint(ports[0], 10, 16)
if err != nil {
- return
+ return ParsedAddress{}, fmt.Errorf("invalid start port: %v", err)
}
- end, err = strconv.Atoi(ports[1])
+ end, err = strconv.ParseUint(ports[1], 10, 16)
if err != nil {
- return
+ return ParsedAddress{}, fmt.Errorf("invalid end port: %v", err)
}
if end < start {
- err = fmt.Errorf("end port must be greater than start port")
- return
+ return ParsedAddress{}, fmt.Errorf("end port must not be less than start port")
}
- for p := start; p <= end; p++ {
- addrs = append(addrs, net.JoinHostPort(host, fmt.Sprintf("%d", p)))
+ if (end - start) > maxPortSpan {
+ return ParsedAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan)
}
- return
+ return ParsedAddress{
+ Network: network,
+ Host: host,
+ StartPort: uint(start),
+ EndPort: uint(end),
+ }, nil
}
// SplitNetworkAddress splits a into its network, host, and port components.
-// Note that port may be a port range, or omitted for unix sockets.
+// Note that port may be a port range (:X-Y), or omitted for unix sockets.
func SplitNetworkAddress(a string) (network, host, port string, err error) {
if idx := strings.Index(a, "/"); idx >= 0 {
network = strings.ToLower(strings.TrimSpace(a[:idx]))
@@ -317,9 +359,9 @@ func SplitNetworkAddress(a string) (network, host, port string, err error) {
}
// JoinNetworkAddress combines network, host, and port into a single
-// address string of the form "network/host:port". Port may be a
-// port range. For unix sockets, the network should be "unix" and
-// the path to the socket should be given in the host argument.
+// address string of the form accepted by ParseNetworkAddress(). For unix sockets, the network
+// should be "unix" and the path to the socket should be given as the
+// host parameter.
func JoinNetworkAddress(network, host, port string) string {
var a string
if network != "" {
@@ -332,3 +374,10 @@ func JoinNetworkAddress(network, host, port string) string {
}
return a
}
+
+var (
+ listeners = make(map[string]*globalListener)
+ listenersMu sync.Mutex
+)
+
+const maxPortSpan = 65535
diff --git a/listeners_fuzz.go b/listeners_fuzz.go
index 98465fd..826c57e 100644
--- a/listeners_fuzz.go
+++ b/listeners_fuzz.go
@@ -18,7 +18,7 @@
package caddy
func FuzzParseNetworkAddress(data []byte) int {
- _, _, err := ParseNetworkAddress(string(data))
+ _, err := ParseNetworkAddress(string(data))
if err != nil {
return 0
}
diff --git a/listeners_test.go b/listeners_test.go
index bdddf32..076b365 100644
--- a/listeners_test.go
+++ b/listeners_test.go
@@ -152,74 +152,101 @@ func TestJoinNetworkAddress(t *testing.T) {
func TestParseNetworkAddress(t *testing.T) {
for i, tc := range []struct {
- input string
- expectNetwork string
- expectAddrs []string
- expectErr bool
+ input string
+ expectAddr ParsedAddress
+ expectErr bool
}{
{
- input: "",
- expectNetwork: "tcp",
- expectErr: true,
+ input: "",
+ expectErr: true,
},
{
- input: ":",
- expectNetwork: "tcp",
- expectErr: true,
+ input: ":",
+ expectErr: true,
},
{
- input: ":1234",
- expectNetwork: "tcp",
- expectAddrs: []string{":1234"},
+ input: ":1234",
+ expectAddr: ParsedAddress{
+ Network: "tcp",
+ Host: "",
+ StartPort: 1234,
+ EndPort: 1234,
+ },
},
{
- input: "tcp/:1234",
- expectNetwork: "tcp",
- expectAddrs: []string{":1234"},
+ input: "tcp/:1234",
+ expectAddr: ParsedAddress{
+ Network: "tcp",
+ Host: "",
+ StartPort: 1234,
+ EndPort: 1234,
+ },
},
{
- input: "tcp6/:1234",
- expectNetwork: "tcp6",
- expectAddrs: []string{":1234"},
+ input: "tcp6/:1234",
+ expectAddr: ParsedAddress{
+ Network: "tcp6",
+ Host: "",
+ StartPort: 1234,
+ EndPort: 1234,
+ },
},
{
- input: "tcp4/localhost:1234",
- expectNetwork: "tcp4",
- expectAddrs: []string{"localhost:1234"},
+ input: "tcp4/localhost:1234",
+ expectAddr: ParsedAddress{
+ Network: "tcp4",
+ Host: "localhost",
+ StartPort: 1234,
+ EndPort: 1234,
+ },
},
{
- input: "unix//foo/bar",
- expectNetwork: "unix",
- expectAddrs: []string{"/foo/bar"},
+ input: "unix//foo/bar",
+ expectAddr: ParsedAddress{
+ Network: "unix",
+ Host: "/foo/bar",
+ },
},
{
- input: "localhost:1234-1234",
- expectNetwork: "tcp",
- expectAddrs: []string{"localhost:1234"},
+ input: "localhost:1234-1234",
+ expectAddr: ParsedAddress{
+ Network: "tcp",
+ Host: "localhost",
+ StartPort: 1234,
+ EndPort: 1234,
+ },
},
{
- input: "localhost:2-1",
- expectNetwork: "tcp",
- expectErr: true,
+ input: "localhost:2-1",
+ expectErr: true,
},
{
- input: "localhost:0",
- expectNetwork: "tcp",
- expectAddrs: []string{"localhost:0"},
+ input: "localhost:0",
+ expectAddr: ParsedAddress{
+ Network: "tcp",
+ Host: "localhost",
+ StartPort: 0,
+ EndPort: 0,
+ },
+ },
+ {
+ input: "localhost:1-999999999999",
+ expectErr: true,
},
} {
- actualNetwork, actualAddrs, err := ParseNetworkAddress(tc.input)
+ actualAddr, err := ParseNetworkAddress(tc.input)
if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err)
}
if !tc.expectErr && err != nil {
t.Errorf("Test %d: Expected no error but got: %v", i, err)
}
- if actualNetwork != tc.expectNetwork {
- t.Errorf("Test %d: Expected network '%s' but got '%s'", i, tc.expectNetwork, actualNetwork)
+
+ if actualAddr.Network != tc.expectAddr.Network {
+ t.Errorf("Test %d: Expected network '%v' but got '%v'", i, tc.expectAddr, actualAddr)
}
- if !reflect.DeepEqual(tc.expectAddrs, actualAddrs) {
- t.Errorf("Test %d: Expected addresses %v but got %v", i, tc.expectAddrs, actualAddrs)
+ if !reflect.DeepEqual(tc.expectAddr, actualAddr) {
+ t.Errorf("Test %d: Expected addresses %v but got %v", i, tc.expectAddr, actualAddr)
}
}
}
diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go
index 99a64c3..36d8154 100644
--- a/modules/caddyhttp/caddyhttp.go
+++ b/modules/caddyhttp/caddyhttp.go
@@ -135,15 +135,18 @@ func (app *App) Validate() error {
lnAddrs := make(map[string]string)
for srvName, srv := range app.Servers {
for _, addr := range srv.Listen {
- netw, expanded, err := caddy.ParseNetworkAddress(addr)
+ listenAddr, err := caddy.ParseNetworkAddress(addr)
if err != nil {
return fmt.Errorf("invalid listener address '%s': %v", addr, err)
}
- for _, a := range expanded {
- if sn, ok := lnAddrs[netw+a]; ok {
- return fmt.Errorf("server %s: listener address repeated: %s (already claimed by server '%s')", srvName, a, sn)
+ // check that every address in the port range is unique to this server;
+ // we do not use <= here because PortRangeSize() adds 1 to EndPort for us
+ for i := uint(0); i < listenAddr.PortRangeSize(); i++ {
+ addr := caddy.JoinNetworkAddress(listenAddr.Network, listenAddr.Host, strconv.Itoa(int(listenAddr.StartPort+i)))
+ if sn, ok := lnAddrs[addr]; ok {
+ return fmt.Errorf("server %s: listener address repeated: %s (already claimed by server '%s')", srvName, addr, sn)
}
- lnAddrs[netw+a] = srvName
+ lnAddrs[addr] = srvName
}
}
}
@@ -176,14 +179,15 @@ func (app *App) Start() error {
}
for _, lnAddr := range srv.Listen {
- network, addrs, err := caddy.ParseNetworkAddress(lnAddr)
+ listenAddr, err := caddy.ParseNetworkAddress(lnAddr)
if err != nil {
return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err)
}
- for _, addr := range addrs {
- ln, err := caddy.Listen(network, addr)
+ for i := uint(0); i <= listenAddr.PortRangeSize(); i++ {
+ hostport := listenAddr.JoinHostPort(i)
+ ln, err := caddy.Listen(listenAddr.Network, hostport)
if err != nil {
- return fmt.Errorf("%s: listening on %s: %v", network, addr, err)
+ return fmt.Errorf("%s: listening on %s: %v", listenAddr.Network, hostport, err)
}
// enable HTTP/2 by default
@@ -194,11 +198,10 @@ func (app *App) Start() error {
}
// enable TLS
- _, port, _ := net.SplitHostPort(addr)
- if len(srv.TLSConnPolicies) > 0 && port != strconv.Itoa(app.httpPort()) {
+ if len(srv.TLSConnPolicies) > 0 && int(i) != app.httpPort() {
tlsCfg, err := srv.TLSConnPolicies.TLSConfig(app.ctx)
if err != nil {
- return fmt.Errorf("%s/%s: making TLS configuration: %v", network, addr, err)
+ return fmt.Errorf("%s/%s: making TLS configuration: %v", listenAddr.Network, hostport, err)
}
ln = tls.NewListener(ln, tlsCfg)
@@ -206,15 +209,15 @@ func (app *App) Start() error {
// TODO: HTTP/3 support is experimental for now
if srv.ExperimentalHTTP3 {
app.logger.Info("enabling experimental HTTP/3 listener",
- zap.String("addr", addr),
+ zap.String("addr", hostport),
)
- h3ln, err := caddy.ListenPacket("udp", addr)
+ h3ln, err := caddy.ListenPacket("udp", hostport)
if err != nil {
return fmt.Errorf("getting HTTP/3 UDP listener: %v", err)
}
h3srv := &http3.Server{
Server: &http.Server{
- Addr: addr,
+ Addr: hostport,
Handler: srv,
TLSConfig: tlsCfg,
},
diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go
index 56e97bc..92b3547 100644
--- a/modules/caddyhttp/reverseproxy/healthchecks.go
+++ b/modules/caddyhttp/reverseproxy/healthchecks.go
@@ -102,7 +102,7 @@ func (h *Handler) doActiveHealthChecksForAllHosts() {
host := value.(Host)
go func(networkAddr string, host Host) {
- network, addrs, err := caddy.ParseNetworkAddress(networkAddr)
+ addr, err := caddy.ParseNetworkAddress(networkAddr)
if err != nil {
h.HealthChecks.Active.logger.Error("bad network address",
zap.String("address", networkAddr),
@@ -110,20 +110,20 @@ func (h *Handler) doActiveHealthChecksForAllHosts() {
)
return
}
- if len(addrs) != 1 {
+ if addr.PortRangeSize() != 1 {
h.HealthChecks.Active.logger.Error("multiple addresses (upstream must map to only one address)",
zap.String("address", networkAddr),
)
return
}
- hostAddr := addrs[0]
- if network == "unix" || network == "unixgram" || network == "unixpacket" {
+ hostAddr := addr.JoinHostPort(0)
+ if addr.Network == "unix" || addr.Network == "unixgram" || addr.Network == "unixpacket" {
// this will be used as the Host portion of a http.Request URL, and
// paths to socket files would produce an error when creating URL,
// so use a fake Host value instead; unix sockets are usually local
hostAddr = "localhost"
}
- err = h.doActiveHealthCheck(DialInfo{Network: network, Address: addrs[0]}, hostAddr, host)
+ err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: hostAddr}, hostAddr, host)
if err != nil {
h.HealthChecks.Active.logger.Error("active health check failed",
zap.String("address", networkAddr),
diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go
index a16bed0..8bad7c2 100644
--- a/modules/caddyhttp/reverseproxy/hosts.go
+++ b/modules/caddyhttp/reverseproxy/hosts.go
@@ -16,8 +16,7 @@ package reverseproxy
import (
"fmt"
- "net"
- "strings"
+ "strconv"
"sync/atomic"
"github.com/caddyserver/caddy/v2"
@@ -193,27 +192,20 @@ func (di DialInfo) String() string {
// the given Replacer. Note that the returned value is not a pointer.
func fillDialInfo(upstream *Upstream, repl caddy.Replacer) (DialInfo, error) {
dial := repl.ReplaceAll(upstream.Dial, "")
- netw, addrs, err := caddy.ParseNetworkAddress(dial)
+ addr, err := caddy.ParseNetworkAddress(dial)
if err != nil {
return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", upstream.Dial, dial, err)
}
- if len(addrs) != 1 {
+ if numPorts := addr.PortRangeSize(); numPorts != 1 {
return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d",
- upstream.Dial, dial, len(addrs))
- }
- var dialHost, dialPort string
- if !strings.Contains(netw, "unix") {
- dialHost, dialPort, err = net.SplitHostPort(addrs[0])
- if err != nil {
- dialHost = addrs[0] // assume there was no port
- }
+ upstream.Dial, dial, numPorts)
}
return DialInfo{
Upstream: upstream,
- Network: netw,
- Address: addrs[0],
- Host: dialHost,
- Port: dialPort,
+ Network: addr.Network,
+ Address: addr.JoinHostPort(0),
+ Host: addr.Host,
+ Port: strconv.Itoa(int(addr.StartPort)),
}, nil
}
diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go
index e119c2d..17860ed 100644
--- a/modules/caddyhttp/server.go
+++ b/modules/caddyhttp/server.go
@@ -242,40 +242,44 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next
// listeners in s that use a port which is not otherPort.
func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool {
for _, lnAddr := range s.Listen {
- _, addrs, err := caddy.ParseNetworkAddress(lnAddr)
- if err == nil {
- for _, a := range addrs {
- _, port, err := net.SplitHostPort(a)
- if err == nil && port != strconv.Itoa(otherPort) {
- return true
- }
- }
+ laddrs, err := caddy.ParseNetworkAddress(lnAddr)
+ if err != nil {
+ continue
+ }
+ if uint(otherPort) > laddrs.EndPort || uint(otherPort) < laddrs.StartPort {
+ return true
}
}
return false
}
+// hasListenerAddress returns true if s has a listener
+// at the given address fullAddr. Currently, fullAddr
+// must represent exactly one socket address (port
+// ranges are not supported)
func (s *Server) hasListenerAddress(fullAddr string) bool {
- netw, addrs, err := caddy.ParseNetworkAddress(fullAddr)
+ laddrs, err := caddy.ParseNetworkAddress(fullAddr)
if err != nil {
return false
}
- if len(addrs) != 1 {
- return false
+ if laddrs.PortRangeSize() != 1 {
+ return false // TODO: support port ranges
}
- addr := addrs[0]
+
for _, lnAddr := range s.Listen {
- thisNetw, thisAddrs, err := caddy.ParseNetworkAddress(lnAddr)
+ thisAddrs, err := caddy.ParseNetworkAddress(lnAddr)
if err != nil {
continue
}
- if thisNetw != netw {
+ if thisAddrs.Network != laddrs.Network {
continue
}
- for _, a := range thisAddrs {
- if a == addr {
- return true
- }
+
+ // host must be the same and port must fall within port range
+ if (thisAddrs.Host == laddrs.Host) &&
+ (laddrs.StartPort <= thisAddrs.EndPort) &&
+ (laddrs.StartPort >= thisAddrs.StartPort) {
+ return true
}
}
return false