diff options
| -rw-r--r-- | listen_unix.go | 4 | ||||
| -rw-r--r-- | listeners.go | 13 | ||||
| -rw-r--r-- | modules/caddyhttp/reverseproxy/addresses.go | 12 | ||||
| -rw-r--r-- | modules/caddyhttp/reverseproxy/addresses_test.go | 244 | 
4 files changed, 255 insertions, 18 deletions
| diff --git a/listen_unix.go b/listen_unix.go index dc955d8..7ea6745 100644 --- a/listen_unix.go +++ b/listen_unix.go @@ -34,7 +34,7 @@ import (  // reuseUnixSocket copies and reuses the unix domain socket (UDS) if we already  // have it open; if not, unlink it so we can have it. No-op if not a unix network.  func reuseUnixSocket(network, addr string) (any, error) { -	if !isUnixNetwork(network) { +	if !IsUnixNetwork(network) {  		return nil, nil  	} @@ -103,7 +103,7 @@ func listenTCPOrUnix(ctx context.Context, lnKey string, network, address string,  // reusePort sets SO_REUSEPORT. Ineffective for unix sockets.  func reusePort(network, address string, conn syscall.RawConn) error { -	if isUnixNetwork(network) { +	if IsUnixNetwork(network) {  		return nil  	}  	return conn.Control(func(descriptor uintptr) { diff --git a/listeners.go b/listeners.go index 5bf85a0..f922144 100644 --- a/listeners.go +++ b/listeners.go @@ -205,7 +205,7 @@ func (na NetworkAddress) listen(ctx context.Context, portOffset uint, config net  // IsUnixNetwork returns true if na.Network is  // unix, unixgram, or unixpacket.  func (na NetworkAddress) IsUnixNetwork() bool { -	return isUnixNetwork(na.Network) +	return IsUnixNetwork(na.Network)  }  // JoinHostPort is like net.JoinHostPort, but where the port @@ -289,8 +289,9 @@ func (na NetworkAddress) String() string {  	return JoinNetworkAddress(na.Network, na.Host, na.port())  } -func isUnixNetwork(netw string) bool { -	return netw == "unix" || netw == "unixgram" || netw == "unixpacket" +// IsUnixNetwork returns true if the netw is a unix network. +func IsUnixNetwork(netw string) bool { +	return strings.HasPrefix(netw, "unix")  }  // ParseNetworkAddress parses addr into its individual @@ -310,7 +311,7 @@ func ParseNetworkAddress(addr string) (NetworkAddress, error) {  	if network == "" {  		network = "tcp"  	} -	if isUnixNetwork(network) { +	if IsUnixNetwork(network) {  		return NetworkAddress{  			Network: network,  			Host:    host, @@ -353,7 +354,7 @@ func SplitNetworkAddress(a string) (network, host, port string, err error) {  		network = strings.ToLower(strings.TrimSpace(beforeSlash))  		a = afterSlash  	} -	if isUnixNetwork(network) { +	if IsUnixNetwork(network) {  		host = a  		return  	} @@ -384,7 +385,7 @@ func JoinNetworkAddress(network, host, port string) string {  	if network != "" {  		a = network + "/"  	} -	if (host != "" && port == "") || isUnixNetwork(network) { +	if (host != "" && port == "") || IsUnixNetwork(network) {  		a += host  	} else if port != "" {  		a += net.JoinHostPort(host, port) diff --git a/modules/caddyhttp/reverseproxy/addresses.go b/modules/caddyhttp/reverseproxy/addresses.go index 4da47fb..8152108 100644 --- a/modules/caddyhttp/reverseproxy/addresses.go +++ b/modules/caddyhttp/reverseproxy/addresses.go @@ -27,9 +27,6 @@ import (  // the dial address, including support for a scheme in front  // as a shortcut for the port number, and a network type,  // for example 'unix' to dial a unix socket. -// -// TODO: the logic in this function is kind of sensitive, we -// need to write tests before making any more changes to it  func parseUpstreamDialAddress(upstreamAddr string) (string, string, error) {  	var network, scheme, host, port string @@ -79,19 +76,14 @@ func parseUpstreamDialAddress(upstreamAddr string) (string, string, error) {  		scheme, host, port = toURL.Scheme, toURL.Hostname(), toURL.Port()  	} else { -		// extract network manually, since caddy.ParseNetworkAddress() will always add one -		if beforeSlash, afterSlash, slashFound := strings.Cut(upstreamAddr, "/"); slashFound { -			network = strings.ToLower(strings.TrimSpace(beforeSlash)) -			upstreamAddr = afterSlash -		}  		var err error -		host, port, err = net.SplitHostPort(upstreamAddr) +		network, host, port, err = caddy.SplitNetworkAddress(upstreamAddr)  		if err != nil {  			host = upstreamAddr  		}  		// we can assume a port if only a hostname is specified, but use of a  		// placeholder without a port likely means a port will be filled in -		if port == "" && !strings.Contains(host, "{") { +		if port == "" && !strings.Contains(host, "{") && !caddy.IsUnixNetwork(network) {  			port = "80"  		}  	} diff --git a/modules/caddyhttp/reverseproxy/addresses_test.go b/modules/caddyhttp/reverseproxy/addresses_test.go new file mode 100644 index 0000000..6355c75 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/addresses_test.go @@ -0,0 +1,244 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +//	http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import "testing" + +func TestParseUpstreamDialAddress(t *testing.T) { +	for i, tc := range []struct { +		input          string +		expectHostPort string +		expectScheme   string +		expectErr      bool +	}{ +		{ +			input:          "foo", +			expectHostPort: "foo:80", +		}, +		{ +			input:          "foo:1234", +			expectHostPort: "foo:1234", +		}, +		{ +			input:          "127.0.0.1", +			expectHostPort: "127.0.0.1:80", +		}, +		{ +			input:          "127.0.0.1:1234", +			expectHostPort: "127.0.0.1:1234", +		}, +		{ +			input:          "[::1]", +			expectHostPort: "[::1]:80", +		}, +		{ +			input:          "[::1]:1234", +			expectHostPort: "[::1]:1234", +		}, +		{ +			input:          "{foo}", +			expectHostPort: "{foo}", +		}, +		{ +			input:          "{foo}:80", +			expectHostPort: "{foo}:80", +		}, +		{ +			input:          "{foo}:{bar}", +			expectHostPort: "{foo}:{bar}", +		}, +		{ +			input:          "http://foo", +			expectHostPort: "foo:80", +			expectScheme:   "http", +		}, +		{ +			input:          "http://foo:1234", +			expectHostPort: "foo:1234", +			expectScheme:   "http", +		}, +		{ +			input:          "http://127.0.0.1", +			expectHostPort: "127.0.0.1:80", +			expectScheme:   "http", +		}, +		{ +			input:          "http://127.0.0.1:1234", +			expectHostPort: "127.0.0.1:1234", +			expectScheme:   "http", +		}, +		{ +			input:          "http://[::1]", +			expectHostPort: "[::1]:80", +			expectScheme:   "http", +		}, +		{ +			input:          "http://[::1]:80", +			expectHostPort: "[::1]:80", +			expectScheme:   "http", +		}, +		{ +			input:          "https://foo", +			expectHostPort: "foo:443", +			expectScheme:   "https", +		}, +		{ +			input:          "https://foo:1234", +			expectHostPort: "foo:1234", +			expectScheme:   "https", +		}, +		{ +			input:          "https://127.0.0.1", +			expectHostPort: "127.0.0.1:443", +			expectScheme:   "https", +		}, +		{ +			input:          "https://127.0.0.1:1234", +			expectHostPort: "127.0.0.1:1234", +			expectScheme:   "https", +		}, +		{ +			input:          "https://[::1]", +			expectHostPort: "[::1]:443", +			expectScheme:   "https", +		}, +		{ +			input:          "https://[::1]:1234", +			expectHostPort: "[::1]:1234", +			expectScheme:   "https", +		}, +		{ +			input:          "h2c://foo", +			expectHostPort: "foo:80", +			expectScheme:   "h2c", +		}, +		{ +			input:          "h2c://foo:1234", +			expectHostPort: "foo:1234", +			expectScheme:   "h2c", +		}, +		{ +			input:          "h2c://127.0.0.1", +			expectHostPort: "127.0.0.1:80", +			expectScheme:   "h2c", +		}, +		{ +			input:          "h2c://127.0.0.1:1234", +			expectHostPort: "127.0.0.1:1234", +			expectScheme:   "h2c", +		}, +		{ +			input:          "h2c://[::1]", +			expectHostPort: "[::1]:80", +			expectScheme:   "h2c", +		}, +		{ +			input:          "h2c://[::1]:1234", +			expectHostPort: "[::1]:1234", +			expectScheme:   "h2c", +		}, +		{ +			input:          "unix//var/php.sock", +			expectHostPort: "unix//var/php.sock", +		}, +		{ +			input:          "unix+h2c//var/grpc.sock", +			expectHostPort: "unix//var/grpc.sock", +			expectScheme:   "h2c", +		}, +		{ +			input:          "unix/{foo}", +			expectHostPort: "unix/{foo}", +		}, +		{ +			input:          "unix+h2c/{foo}", +			expectHostPort: "unix/{foo}", +			expectScheme:   "h2c", +		}, +		{ +			input:          "unix//foo/{foo}/bar", +			expectHostPort: "unix//foo/{foo}/bar", +		}, +		{ +			input:          "unix+h2c//foo/{foo}/bar", +			expectHostPort: "unix//foo/{foo}/bar", +			expectScheme:   "h2c", +		}, +		{ +			input:     "http://{foo}", +			expectErr: true, +		}, +		{ +			input:     "http:// :80", +			expectErr: true, +		}, +		{ +			input:     "http://localhost/path", +			expectErr: true, +		}, +		{ +			input:     "http://localhost?key=value", +			expectErr: true, +		}, +		{ +			input:     "http://localhost#fragment", +			expectErr: true, +		}, +		{ +			input:     "http://foo:443", +			expectErr: true, +		}, +		{ +			input:     "https://foo:80", +			expectErr: true, +		}, +		{ +			input:     "h2c://foo:443", +			expectErr: true, +		}, +		{ +			input:          `unix/c:\absolute\path`, +			expectHostPort: `unix/c:\absolute\path`, +		}, +		{ +			input:          `unix+h2c/c:\absolute\path`, +			expectHostPort: `unix/c:\absolute\path`, +			expectScheme:   "h2c", +		}, +		{ +			input:          "unix/c:/absolute/path", +			expectHostPort: "unix/c:/absolute/path", +		}, +		{ +			input:          "unix+h2c/c:/absolute/path", +			expectHostPort: "unix/c:/absolute/path", +			expectScheme:   "h2c", +		}, +	} { +		actualHostPort, actualScheme, err := parseUpstreamDialAddress(tc.input) +		if tc.expectErr && err == nil { +			t.Errorf("Test %d: Expected error but got %v", i, err) +		} +		if !tc.expectErr && err != nil { +			t.Errorf("Test %d: Expected no error but got %v", i, err) +		} +		if actualHostPort != tc.expectHostPort { +			t.Errorf("Test %d: Expected host and port '%s' but got '%s'", i, tc.expectHostPort, actualHostPort) +		} +		if actualScheme != tc.expectScheme { +			t.Errorf("Test %d: Expected scheme '%s' but got '%s'", i, tc.expectScheme, actualScheme) +		} +	} +} | 
