diff options
-rw-r--r-- | caddyconfig/httpcaddyfile/addresses.go | 14 | ||||
-rw-r--r-- | caddyconfig/httpcaddyfile/httptype.go | 30 | ||||
-rw-r--r-- | caddytest/integration/caddyfile_test.go | 2 | ||||
-rw-r--r-- | listeners.go | 64 | ||||
-rw-r--r-- | listeners_test.go | 34 |
5 files changed, 99 insertions, 45 deletions
diff --git a/caddyconfig/httpcaddyfile/addresses.go b/caddyconfig/httpcaddyfile/addresses.go index c7923e8..03083d8 100644 --- a/caddyconfig/httpcaddyfile/addresses.go +++ b/caddyconfig/httpcaddyfile/addresses.go @@ -183,6 +183,8 @@ func (st *ServerType) consolidateAddrMappings(addrToServerBlocks map[string][]se return sbaddrs } +// listenerAddrsForServerBlockKey essentially converts the Caddyfile +// site addresses to Caddy listener addresses for each server block. func (st *ServerType) listenerAddrsForServerBlockKey(sblock serverBlock, key string, options map[string]interface{}) ([]string, error) { addr, err := ParseAddress(key) @@ -232,12 +234,14 @@ func (st *ServerType) listenerAddrsForServerBlockKey(sblock serverBlock, key str // use a map to prevent duplication listeners := make(map[string]struct{}) for _, host := range lnHosts { - addr, err := caddy.ParseNetworkAddress(host) - if err == nil && addr.IsUnixNetwork() { - listeners[host] = struct{}{} - } else { - listeners[host+":"+lnPort] = struct{}{} + // host can have network + host (e.g. "tcp6/localhost") but + // will/should not have port information because this usually + // comes from the bind directive, so we append the port + addr, err := caddy.ParseNetworkAddress(host + ":" + lnPort) + if err != nil { + return nil, fmt.Errorf("parsing network address: %v", err) } + listeners[addr.String()] = struct{}{} } // now turn map into list diff --git a/caddyconfig/httpcaddyfile/httptype.go b/caddyconfig/httpcaddyfile/httptype.go index 5d78244..b114059 100644 --- a/caddyconfig/httpcaddyfile/httptype.go +++ b/caddyconfig/httpcaddyfile/httptype.go @@ -58,22 +58,13 @@ func (st ServerType) Setup(inputServerBlocks []caddyfile.ServerBlock, gc := counter{new(int)} state := make(map[string]interface{}) - // load all the server blocks and associate them with a "pile" - // of config values; also prohibit duplicate keys because they - // can make a config confusing if more than one server block is - // chosen to handle a request - we actually will make each - // server block's route terminal so that only one will run - sbKeys := make(map[string]struct{}) + // load all the server blocks and associate them with a "pile" of config values originalServerBlocks := make([]serverBlock, 0, len(inputServerBlocks)) - for i, sblock := range inputServerBlocks { + for _, sblock := range inputServerBlocks { for j, k := range sblock.Keys { if j == 0 && strings.HasPrefix(k, "@") { return nil, warnings, fmt.Errorf("cannot define a matcher outside of a site block: '%s'", k) } - if _, ok := sbKeys[k]; ok { - return nil, warnings, fmt.Errorf("duplicate site address not allowed: '%s' in %v (site block %d, key %d)", k, sblock.Keys, i, j) - } - sbKeys[k] = struct{}{} } originalServerBlocks = append(originalServerBlocks, serverBlock{ block: sblock, @@ -420,6 +411,23 @@ func (st *ServerType) serversFromPairings( } for i, p := range pairings { + // detect ambiguous site definitions: server blocks which + // have the same host bound to the same interface (listener + // address), otherwise their routes will improperly be added + // to the same server (see issue #4635) + for j, sblock1 := range p.serverBlocks { + for _, key := range sblock1.block.Keys { + for k, sblock2 := range p.serverBlocks { + if k == j { + continue + } + if sliceContains(sblock2.block.Keys, key) { + return nil, fmt.Errorf("ambiguous site definition: %s", key) + } + } + } + } + srv := &caddyhttp.Server{ Listen: p.addresses, } diff --git a/caddytest/integration/caddyfile_test.go b/caddytest/integration/caddyfile_test.go index be85f4a..2758883 100644 --- a/caddytest/integration/caddyfile_test.go +++ b/caddytest/integration/caddyfile_test.go @@ -68,7 +68,7 @@ func TestDuplicateHosts(t *testing.T) { } `, "caddyfile", - "duplicate site address not allowed") + "ambiguous site definition") } func TestReadCookie(t *testing.T) { diff --git a/listeners.go b/listeners.go index f2d7e10..4c86e82 100644 --- a/listeners.go +++ b/listeners.go @@ -391,10 +391,13 @@ func (na NetworkAddress) port() string { return fmt.Sprintf("%d-%d", na.StartPort, na.EndPort) } -// String reconstructs the address string to the form expected -// by ParseNetworkAddress(). If the address is a unix socket, -// any non-zero port will be dropped. +// String reconstructs the address string for human display. +// The output can be parsed by ParseNetworkAddress(). If the +// address is a unix socket, any non-zero port will be dropped. func (na NetworkAddress) String() string { + if na.Network == "tcp" && (na.Host != "" || na.port() != "") { + na.Network = "" // omit default network value for brevity + } return JoinNetworkAddress(na.Network, na.Host, na.port()) } @@ -427,36 +430,38 @@ func isListenBindAddressAlreadyInUseError(err error) bool { func ParseNetworkAddress(addr string) (NetworkAddress, error) { var host, port string network, host, port, err := SplitNetworkAddress(addr) - if network == "" { - network = "tcp" - } if err != nil { return NetworkAddress{}, err } + if network == "" { + network = "tcp" + } if isUnixNetwork(network) { return NetworkAddress{ Network: network, Host: host, }, nil } - ports := strings.SplitN(port, "-", 2) - if len(ports) == 1 { - ports = append(ports, ports[0]) - } var start, end uint64 - start, err = strconv.ParseUint(ports[0], 10, 16) - if err != nil { - return NetworkAddress{}, fmt.Errorf("invalid start port: %v", err) - } - end, err = strconv.ParseUint(ports[1], 10, 16) - if err != nil { - return NetworkAddress{}, fmt.Errorf("invalid end port: %v", err) - } - if end < start { - return NetworkAddress{}, fmt.Errorf("end port must not be less than start port") - } - if (end - start) > maxPortSpan { - return NetworkAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan) + if port != "" { + ports := strings.SplitN(port, "-", 2) + if len(ports) == 1 { + ports = append(ports, ports[0]) + } + start, err = strconv.ParseUint(ports[0], 10, 16) + if err != nil { + return NetworkAddress{}, fmt.Errorf("invalid start port: %v", err) + } + end, err = strconv.ParseUint(ports[1], 10, 16) + if err != nil { + return NetworkAddress{}, fmt.Errorf("invalid end port: %v", err) + } + if end < start { + return NetworkAddress{}, fmt.Errorf("end port must not be less than start port") + } + if (end - start) > maxPortSpan { + return NetworkAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan) + } } return NetworkAddress{ Network: network, @@ -478,6 +483,19 @@ func SplitNetworkAddress(a string) (network, host, port string, err error) { return } host, port, err = net.SplitHostPort(a) + if err == nil || a == "" { + return + } + // in general, if there was an error, it was likely "missing port", + // so try adding a bogus port to take advantage of standard library's + // robust parser, then strip the artificial port before returning + // (don't overwrite original error though; might still be relevant) + var err2 error + host, port, err2 = net.SplitHostPort(a + ":0") + if err2 == nil { + err = nil + port = "" + } return } diff --git a/listeners_test.go b/listeners_test.go index b75e2dc..6b0f440 100644 --- a/listeners_test.go +++ b/listeners_test.go @@ -32,10 +32,25 @@ func TestSplitNetworkAddress(t *testing.T) { expectErr: true, }, { - input: "foo", + input: "foo", + expectHost: "foo", + }, + { + input: ":", // empty host & empty port + }, + { + input: "::", expectErr: true, }, { + input: "[::]", + expectHost: "::", + }, + { + input: ":1234", + expectPort: "1234", + }, + { input: "foo:1234", expectHost: "foo", expectPort: "1234", @@ -80,10 +95,10 @@ func TestSplitNetworkAddress(t *testing.T) { } { actualNetwork, actualHost, actualPort, err := SplitNetworkAddress(tc.input) if tc.expectErr && err == nil { - t.Errorf("Test %d: Expected error but got: %v", i, err) + t.Errorf("Test %d: Expected error but got %v", i, err) } if !tc.expectErr && err != nil { - t.Errorf("Test %d: Expected no error but got: %v", i, err) + t.Errorf("Test %d: Expected no error but got %v", i, err) } if actualNetwork != tc.expectNetwork { t.Errorf("Test %d: Expected network '%s' but got '%s'", i, tc.expectNetwork, actualNetwork) @@ -169,8 +184,17 @@ func TestParseNetworkAddress(t *testing.T) { expectErr: true, }, { - input: ":", - expectErr: true, + input: ":", + expectAddr: NetworkAddress{ + Network: "tcp", + }, + }, + { + input: "[::]", + expectAddr: NetworkAddress{ + Network: "tcp", + Host: "::", + }, }, { input: ":1234", |