From a3bdc22234b75e9420f8810918072fa34732ffb7 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Fri, 10 Apr 2020 17:31:38 -0600 Subject: admin: Always enforce Host header checks With a simple heuristic for loopback addresses, we can enable this by default without adding unnecessary inconvenience. --- listeners.go | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) (limited to 'listeners.go') diff --git a/listeners.go b/listeners.go index e1fd48c..bfbe6dd 100644 --- a/listeners.go +++ b/listeners.go @@ -289,14 +289,31 @@ func (na NetworkAddress) PortRangeSize() uint { return (na.EndPort - na.StartPort) + 1 } +func (na NetworkAddress) isLoopback() bool { + if na.IsUnixNetwork() { + return true + } + if na.Host == "localhost" { + return true + } + if ip := net.ParseIP(na.Host); ip != nil { + return ip.IsLoopback() + } + return false +} + +func (na NetworkAddress) port() string { + if na.StartPort == na.EndPort { + return strconv.FormatUint(uint64(na.StartPort), 10) + } + return fmt.Sprintf("%d-%d", na.StartPort, na.EndPort) +} + // String reconstructs the address string to the form expected -// by ParseNetworkAddress(). +// by ParseNetworkAddress(). If the address is a unix socket, +// any non-zero port will be dropped. func (na NetworkAddress) String() string { - port := strconv.FormatUint(uint64(na.StartPort), 10) - if na.StartPort != na.EndPort { - port += "-" + strconv.FormatUint(uint64(na.EndPort), 10) - } - return JoinNetworkAddress(na.Network, na.Host, port) + return JoinNetworkAddress(na.Network, na.Host, na.port()) } func isUnixNetwork(netw string) bool { @@ -378,7 +395,7 @@ func JoinNetworkAddress(network, host, port string) string { if network != "" { a = network + "/" } - if host != "" && port == "" { + if (host != "" && port == "") || isUnixNetwork(network) { a += host } else if port != "" { a += net.JoinHostPort(host, port) -- cgit v1.2.3