diff options
author | Tom Barrett <tom@tombarrett.xyz> | 2023-11-01 17:57:48 +0100 |
---|---|---|
committer | Tom Barrett <tom@tombarrett.xyz> | 2023-11-01 18:11:33 +0100 |
commit | 240c3d1338415e5d82ef7ca0e52c4284be6441bd (patch) | |
tree | 4b0ee5d208c2cdffa78d65f1b0abe0ec85f15652 /modules/caddyhttp/reverseproxy | |
parent | 73e78ab226f21e6c6c68961af88c4ab9c746f4f4 (diff) | |
parent | 0e204b730aa2b1fa0835336b1117eff8c420f713 (diff) |
vbump to v2.7.5
Diffstat (limited to 'modules/caddyhttp/reverseproxy')
18 files changed, 2131 insertions, 970 deletions
diff --git a/modules/caddyhttp/reverseproxy/addresses.go b/modules/caddyhttp/reverseproxy/addresses.go index 8152108..82c1c79 100644 --- a/modules/caddyhttp/reverseproxy/addresses.go +++ b/modules/caddyhttp/reverseproxy/addresses.go @@ -23,11 +23,46 @@ import ( "github.com/caddyserver/caddy/v2" ) +type parsedAddr struct { + network, scheme, host, port string + valid bool +} + +func (p parsedAddr) dialAddr() string { + if !p.valid { + return "" + } + // for simplest possible config, we only need to include + // the network portion if the user specified one + if p.network != "" { + return caddy.JoinNetworkAddress(p.network, p.host, p.port) + } + + // if the host is a placeholder, then we don't want to join with an empty port, + // because that would just append an extra ':' at the end of the address. + if p.port == "" && strings.Contains(p.host, "{") { + return p.host + } + return net.JoinHostPort(p.host, p.port) +} + +func (p parsedAddr) rangedPort() bool { + return strings.Contains(p.port, "-") +} + +func (p parsedAddr) replaceablePort() bool { + return strings.Contains(p.port, "{") && strings.Contains(p.port, "}") +} + +func (p parsedAddr) isUnix() bool { + return caddy.IsUnixNetwork(p.network) +} + // parseUpstreamDialAddress parses configuration inputs for // the dial address, including support for a scheme in front // as a shortcut for the port number, and a network type, // for example 'unix' to dial a unix socket. -func parseUpstreamDialAddress(upstreamAddr string) (string, string, error) { +func parseUpstreamDialAddress(upstreamAddr string) (parsedAddr, error) { var network, scheme, host, port string if strings.Contains(upstreamAddr, "://") { @@ -35,46 +70,65 @@ func parseUpstreamDialAddress(upstreamAddr string) (string, string, error) { // so we return a more user-friendly error message instead // to explain what to do instead if strings.Contains(upstreamAddr, "{") { - return "", "", fmt.Errorf("due to parsing difficulties, placeholders are not allowed when an upstream address contains a scheme") + return parsedAddr{}, fmt.Errorf("due to parsing difficulties, placeholders are not allowed when an upstream address contains a scheme") } toURL, err := url.Parse(upstreamAddr) if err != nil { - return "", "", fmt.Errorf("parsing upstream URL: %v", err) + // if the error seems to be due to a port range, + // try to replace the port range with a dummy + // single port so that url.Parse() will succeed + if strings.Contains(err.Error(), "invalid port") && strings.Contains(err.Error(), "-") { + index := strings.LastIndex(upstreamAddr, ":") + if index == -1 { + return parsedAddr{}, fmt.Errorf("parsing upstream URL: %v", err) + } + portRange := upstreamAddr[index+1:] + if strings.Count(portRange, "-") != 1 { + return parsedAddr{}, fmt.Errorf("parsing upstream URL: parse \"%v\": port range invalid: %v", upstreamAddr, portRange) + } + toURL, err = url.Parse(strings.ReplaceAll(upstreamAddr, portRange, "0")) + if err != nil { + return parsedAddr{}, fmt.Errorf("parsing upstream URL: %v", err) + } + port = portRange + } else { + return parsedAddr{}, fmt.Errorf("parsing upstream URL: %v", err) + } + } + if port == "" { + port = toURL.Port() } // there is currently no way to perform a URL rewrite between choosing // a backend and proxying to it, so we cannot allow extra components // in backend URLs if toURL.Path != "" || toURL.RawQuery != "" || toURL.Fragment != "" { - return "", "", fmt.Errorf("for now, URLs for proxy upstreams only support scheme, host, and port components") + return parsedAddr{}, fmt.Errorf("for now, URLs for proxy upstreams only support scheme, host, and port components") } // ensure the port and scheme aren't in conflict - urlPort := toURL.Port() - if toURL.Scheme == "http" && urlPort == "443" { - return "", "", fmt.Errorf("upstream address has conflicting scheme (http://) and port (:443, the HTTPS port)") + if toURL.Scheme == "http" && port == "443" { + return parsedAddr{}, fmt.Errorf("upstream address has conflicting scheme (http://) and port (:443, the HTTPS port)") } - if toURL.Scheme == "https" && urlPort == "80" { - return "", "", fmt.Errorf("upstream address has conflicting scheme (https://) and port (:80, the HTTP port)") + if toURL.Scheme == "https" && port == "80" { + return parsedAddr{}, fmt.Errorf("upstream address has conflicting scheme (https://) and port (:80, the HTTP port)") } - if toURL.Scheme == "h2c" && urlPort == "443" { - return "", "", fmt.Errorf("upstream address has conflicting scheme (h2c://) and port (:443, the HTTPS port)") + if toURL.Scheme == "h2c" && port == "443" { + return parsedAddr{}, fmt.Errorf("upstream address has conflicting scheme (h2c://) and port (:443, the HTTPS port)") } // if port is missing, attempt to infer from scheme - if toURL.Port() == "" { - var toPort string + if port == "" { switch toURL.Scheme { case "", "http", "h2c": - toPort = "80" + port = "80" case "https": - toPort = "443" + port = "443" } - toURL.Host = net.JoinHostPort(toURL.Hostname(), toPort) } - scheme, host, port = toURL.Scheme, toURL.Hostname(), toURL.Port() + scheme, host = toURL.Scheme, toURL.Hostname() } else { var err error network, host, port, err = caddy.SplitNetworkAddress(upstreamAddr) @@ -93,18 +147,5 @@ func parseUpstreamDialAddress(upstreamAddr string) (string, string, error) { network = "unix" scheme = "h2c" } - - // for simplest possible config, we only need to include - // the network portion if the user specified one - if network != "" { - return caddy.JoinNetworkAddress(network, host, port), scheme, nil - } - - // if the host is a placeholder, then we don't want to join with an empty port, - // because that would just append an extra ':' at the end of the address. - if port == "" && strings.Contains(host, "{") { - return host, scheme, nil - } - - return net.JoinHostPort(host, port), scheme, nil + return parsedAddr{network, scheme, host, port, true}, nil } diff --git a/modules/caddyhttp/reverseproxy/addresses_test.go b/modules/caddyhttp/reverseproxy/addresses_test.go index 6355c75..0c51419 100644 --- a/modules/caddyhttp/reverseproxy/addresses_test.go +++ b/modules/caddyhttp/reverseproxy/addresses_test.go @@ -150,6 +150,24 @@ func TestParseUpstreamDialAddress(t *testing.T) { expectScheme: "h2c", }, { + input: "localhost:1001-1009", + expectHostPort: "localhost:1001-1009", + }, + { + input: "{host}:1001-1009", + expectHostPort: "{host}:1001-1009", + }, + { + input: "http://localhost:1001-1009", + expectHostPort: "localhost:1001-1009", + expectScheme: "http", + }, + { + input: "https://localhost:1001-1009", + expectHostPort: "localhost:1001-1009", + expectScheme: "https", + }, + { input: "unix//var/php.sock", expectHostPort: "unix//var/php.sock", }, @@ -197,6 +215,26 @@ func TestParseUpstreamDialAddress(t *testing.T) { expectErr: true, }, { + input: "http://localhost:8001-8002-8003", + expectErr: true, + }, + { + input: "http://localhost:8001-8002/foo:bar", + expectErr: true, + }, + { + input: "http://localhost:8001-8002/foo:1", + expectErr: true, + }, + { + input: "http://localhost:8001-8002/foo:1-2", + expectErr: true, + }, + { + input: "http://localhost:8001-8002#foo:1", + expectErr: true, + }, + { input: "http://foo:443", expectErr: true, }, @@ -227,18 +265,18 @@ func TestParseUpstreamDialAddress(t *testing.T) { expectScheme: "h2c", }, } { - actualHostPort, actualScheme, err := parseUpstreamDialAddress(tc.input) + actualAddr, err := parseUpstreamDialAddress(tc.input) if tc.expectErr && err == nil { t.Errorf("Test %d: Expected error but got %v", i, err) } if !tc.expectErr && err != nil { t.Errorf("Test %d: Expected no error but got %v", i, err) } - if actualHostPort != tc.expectHostPort { - t.Errorf("Test %d: Expected host and port '%s' but got '%s'", i, tc.expectHostPort, actualHostPort) + if actualAddr.dialAddr() != tc.expectHostPort { + t.Errorf("Test %d: input %s: Expected host and port '%s' but got '%s'", i, tc.input, tc.expectHostPort, actualAddr.dialAddr()) } - if actualScheme != tc.expectScheme { - t.Errorf("Test %d: Expected scheme '%s' but got '%s'", i, tc.expectScheme, actualScheme) + if actualAddr.scheme != tc.expectScheme { + t.Errorf("Test %d: Expected scheme '%s' but got '%s'", i, tc.expectScheme, actualAddr.scheme) } } } diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go index 1211188..bcbe744 100644 --- a/modules/caddyhttp/reverseproxy/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -15,12 +15,14 @@ package reverseproxy import ( - "net" + "fmt" "net/http" "reflect" "strconv" "strings" + "github.com/dustin/go-humanize" + "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" @@ -28,7 +30,6 @@ import ( "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/caddyserver/caddy/v2/modules/caddyhttp/headers" "github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite" - "github.com/dustin/go-humanize" ) func init() { @@ -83,10 +84,13 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) // unhealthy_request_count <num> // // # streaming -// flush_interval <duration> +// flush_interval <duration> // buffer_requests // buffer_responses -// max_buffer_size <size> +// max_buffer_size <size> +// stream_timeout <duration> +// stream_close_delay <duration> +// trace_logs // // # request manipulation // trusted_proxies [private_ranges] <ranges...> @@ -142,16 +146,9 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { h.responseMatchers = make(map[string]caddyhttp.ResponseMatcher) // appendUpstream creates an upstream for address and adds - // it to the list. If the address starts with "srv+" it is - // treated as a SRV-based upstream, and any port will be - // dropped. + // it to the list. appendUpstream := func(address string) error { - isSRV := strings.HasPrefix(address, "srv+") - if isSRV { - address = strings.TrimPrefix(address, "srv+") - } - - dialAddr, scheme, err := parseUpstreamDialAddress(address) + pa, err := parseUpstreamDialAddress(address) if err != nil { return d.WrapErr(err) } @@ -159,573 +156,641 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { // the underlying JSON does not yet support different // transports (protocols or schemes) to each backend, // so we remember the last one we see and compare them - if commonScheme != "" && scheme != commonScheme { + + switch pa.scheme { + case "wss": + return d.Errf("the scheme wss:// is only supported in browsers; use https:// instead") + case "ws": + return d.Errf("the scheme ws:// is only supported in browsers; use http:// instead") + case "https", "http", "h2c", "": + // Do nothing or handle the valid schemes + default: + return d.Errf("unsupported URL scheme %s://", pa.scheme) + } + + if commonScheme != "" && pa.scheme != commonScheme { return d.Errf("for now, all proxy upstreams must use the same scheme (transport protocol); expecting '%s://' but got '%s://'", - commonScheme, scheme) + commonScheme, pa.scheme) } - commonScheme = scheme + commonScheme = pa.scheme - if isSRV { - if host, _, err := net.SplitHostPort(dialAddr); err == nil { - dialAddr = host - } - h.Upstreams = append(h.Upstreams, &Upstream{LookupSRV: dialAddr}) + // if the port of upstream address contains a placeholder, only wrap it with the `Upstream` struct, + // delaying actual resolution of the address until request time. + if pa.replaceablePort() { + h.Upstreams = append(h.Upstreams, &Upstream{Dial: pa.dialAddr()}) + return nil + } + parsedAddr, err := caddy.ParseNetworkAddress(pa.dialAddr()) + if err != nil { + return d.WrapErr(err) + } + + if pa.isUnix() || !pa.rangedPort() { + // unix networks don't have ports + h.Upstreams = append(h.Upstreams, &Upstream{ + Dial: pa.dialAddr(), + }) } else { - h.Upstreams = append(h.Upstreams, &Upstream{Dial: dialAddr}) + // expand a port range into multiple upstreams + for i := parsedAddr.StartPort; i <= parsedAddr.EndPort; i++ { + h.Upstreams = append(h.Upstreams, &Upstream{ + Dial: caddy.JoinNetworkAddress("", parsedAddr.Host, fmt.Sprint(i)), + }) + } } + return nil } - for d.Next() { - for _, up := range d.RemainingArgs() { - err := appendUpstream(up) + d.Next() // consume the directive name + for _, up := range d.RemainingArgs() { + err := appendUpstream(up) + if err != nil { + return fmt.Errorf("parsing upstream '%s': %w", up, err) + } + } + + for nesting := d.Nesting(); d.NextBlock(nesting); { + // if the subdirective has an "@" prefix then we + // parse it as a response matcher for use with "handle_response" + if strings.HasPrefix(d.Val(), matcherPrefix) { + err := caddyhttp.ParseNamedResponseMatcher(d.NewFromNextSegment(), h.responseMatchers) if err != nil { return err } + continue } - for d.NextBlock(0) { - // if the subdirective has an "@" prefix then we - // parse it as a response matcher for use with "handle_response" - if strings.HasPrefix(d.Val(), matcherPrefix) { - err := caddyhttp.ParseNamedResponseMatcher(d.NewFromNextSegment(), h.responseMatchers) + switch d.Val() { + case "to": + args := d.RemainingArgs() + if len(args) == 0 { + return d.ArgErr() + } + for _, up := range args { + err := appendUpstream(up) if err != nil { - return err + return fmt.Errorf("parsing upstream '%s': %w", up, err) } - continue } - switch d.Val() { - case "to": - args := d.RemainingArgs() - if len(args) == 0 { - return d.ArgErr() - } - for _, up := range args { - err := appendUpstream(up) - if err != nil { - return err - } - } + case "dynamic": + if !d.NextArg() { + return d.ArgErr() + } + if h.DynamicUpstreams != nil { + return d.Err("dynamic upstreams already specified") + } + dynModule := d.Val() + modID := "http.reverse_proxy.upstreams." + dynModule + unm, err := caddyfile.UnmarshalModule(d, modID) + if err != nil { + return err + } + source, ok := unm.(UpstreamSource) + if !ok { + return d.Errf("module %s (%T) is not an UpstreamSource", modID, unm) + } + h.DynamicUpstreamsRaw = caddyconfig.JSONModuleObject(source, "source", dynModule, nil) - case "dynamic": - if !d.NextArg() { - return d.ArgErr() - } - if h.DynamicUpstreams != nil { - return d.Err("dynamic upstreams already specified") - } - dynModule := d.Val() - modID := "http.reverse_proxy.upstreams." + dynModule - unm, err := caddyfile.UnmarshalModule(d, modID) - if err != nil { - return err - } - source, ok := unm.(UpstreamSource) - if !ok { - return d.Errf("module %s (%T) is not an UpstreamSource", modID, unm) - } - h.DynamicUpstreamsRaw = caddyconfig.JSONModuleObject(source, "source", dynModule, nil) + case "lb_policy": + if !d.NextArg() { + return d.ArgErr() + } + if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil { + return d.Err("load balancing selection policy already specified") + } + name := d.Val() + modID := "http.reverse_proxy.selection_policies." + name + unm, err := caddyfile.UnmarshalModule(d, modID) + if err != nil { + return err + } + sel, ok := unm.(Selector) + if !ok { + return d.Errf("module %s (%T) is not a reverseproxy.Selector", modID, unm) + } + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + h.LoadBalancing.SelectionPolicyRaw = caddyconfig.JSONModuleObject(sel, "policy", name, nil) - case "lb_policy": - if !d.NextArg() { - return d.ArgErr() - } - if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil { - return d.Err("load balancing selection policy already specified") - } - name := d.Val() - modID := "http.reverse_proxy.selection_policies." + name - unm, err := caddyfile.UnmarshalModule(d, modID) - if err != nil { - return err - } - sel, ok := unm.(Selector) - if !ok { - return d.Errf("module %s (%T) is not a reverseproxy.Selector", modID, unm) - } - if h.LoadBalancing == nil { - h.LoadBalancing = new(LoadBalancing) - } - h.LoadBalancing.SelectionPolicyRaw = caddyconfig.JSONModuleObject(sel, "policy", name, nil) + case "lb_retries": + if !d.NextArg() { + return d.ArgErr() + } + tries, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("bad lb_retries number '%s': %v", d.Val(), err) + } + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + h.LoadBalancing.Retries = tries - case "lb_retries": - if !d.NextArg() { - return d.ArgErr() - } - tries, err := strconv.Atoi(d.Val()) - if err != nil { - return d.Errf("bad lb_retries number '%s': %v", d.Val(), err) - } - if h.LoadBalancing == nil { - h.LoadBalancing = new(LoadBalancing) - } - h.LoadBalancing.Retries = tries + case "lb_try_duration": + if !d.NextArg() { + return d.ArgErr() + } + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value %s: %v", d.Val(), err) + } + h.LoadBalancing.TryDuration = caddy.Duration(dur) - case "lb_try_duration": - if !d.NextArg() { - return d.ArgErr() - } - if h.LoadBalancing == nil { - h.LoadBalancing = new(LoadBalancing) - } - dur, err := caddy.ParseDuration(d.Val()) - if err != nil { - return d.Errf("bad duration value %s: %v", d.Val(), err) - } - h.LoadBalancing.TryDuration = caddy.Duration(dur) + case "lb_try_interval": + if !d.NextArg() { + return d.ArgErr() + } + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad interval value '%s': %v", d.Val(), err) + } + h.LoadBalancing.TryInterval = caddy.Duration(dur) - case "lb_try_interval": - if !d.NextArg() { - return d.ArgErr() - } - if h.LoadBalancing == nil { - h.LoadBalancing = new(LoadBalancing) - } - dur, err := caddy.ParseDuration(d.Val()) - if err != nil { - return d.Errf("bad interval value '%s': %v", d.Val(), err) - } - h.LoadBalancing.TryInterval = caddy.Duration(dur) + case "lb_retry_match": + matcherSet, err := caddyhttp.ParseCaddyfileNestedMatcherSet(d) + if err != nil { + return d.Errf("failed to parse lb_retry_match: %v", err) + } + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + h.LoadBalancing.RetryMatchRaw = append(h.LoadBalancing.RetryMatchRaw, matcherSet) - case "lb_retry_match": - matcherSet, err := caddyhttp.ParseCaddyfileNestedMatcherSet(d) - if err != nil { - return d.Errf("failed to parse lb_retry_match: %v", err) - } - if h.LoadBalancing == nil { - h.LoadBalancing = new(LoadBalancing) - } - h.LoadBalancing.RetryMatchRaw = append(h.LoadBalancing.RetryMatchRaw, matcherSet) + case "health_uri": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + h.HealthChecks.Active.URI = d.Val() - case "health_uri": - if !d.NextArg() { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Active == nil { - h.HealthChecks.Active = new(ActiveHealthChecks) - } - h.HealthChecks.Active.URI = d.Val() + case "health_path": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + h.HealthChecks.Active.Path = d.Val() + caddy.Log().Named("config.adapter.caddyfile").Warn("the 'health_path' subdirective is deprecated, please use 'health_uri' instead!") - case "health_path": - if !d.NextArg() { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Active == nil { - h.HealthChecks.Active = new(ActiveHealthChecks) - } - h.HealthChecks.Active.Path = d.Val() - caddy.Log().Named("config.adapter.caddyfile").Warn("the 'health_path' subdirective is deprecated, please use 'health_uri' instead!") + case "health_port": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + portNum, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("bad port number '%s': %v", d.Val(), err) + } + h.HealthChecks.Active.Port = portNum - case "health_port": - if !d.NextArg() { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Active == nil { - h.HealthChecks.Active = new(ActiveHealthChecks) + case "health_headers": + healthHeaders := make(http.Header) + for nesting := d.Nesting(); d.NextBlock(nesting); { + key := d.Val() + values := d.RemainingArgs() + if len(values) == 0 { + values = append(values, "") } - portNum, err := strconv.Atoi(d.Val()) - if err != nil { - return d.Errf("bad port number '%s': %v", d.Val(), err) - } - h.HealthChecks.Active.Port = portNum + healthHeaders[key] = append(healthHeaders[key], values...) + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + h.HealthChecks.Active.Headers = healthHeaders - case "health_headers": - healthHeaders := make(http.Header) - for nesting := d.Nesting(); d.NextBlock(nesting); { - key := d.Val() - values := d.RemainingArgs() - if len(values) == 0 { - values = append(values, "") - } - healthHeaders[key] = values - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Active == nil { - h.HealthChecks.Active = new(ActiveHealthChecks) - } - h.HealthChecks.Active.Headers = healthHeaders + case "health_interval": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad interval value %s: %v", d.Val(), err) + } + h.HealthChecks.Active.Interval = caddy.Duration(dur) - case "health_interval": - if !d.NextArg() { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Active == nil { - h.HealthChecks.Active = new(ActiveHealthChecks) - } - dur, err := caddy.ParseDuration(d.Val()) - if err != nil { - return d.Errf("bad interval value %s: %v", d.Val(), err) - } - h.HealthChecks.Active.Interval = caddy.Duration(dur) + case "health_timeout": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad timeout value %s: %v", d.Val(), err) + } + h.HealthChecks.Active.Timeout = caddy.Duration(dur) - case "health_timeout": - if !d.NextArg() { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Active == nil { - h.HealthChecks.Active = new(ActiveHealthChecks) - } - dur, err := caddy.ParseDuration(d.Val()) - if err != nil { - return d.Errf("bad timeout value %s: %v", d.Val(), err) - } - h.HealthChecks.Active.Timeout = caddy.Duration(dur) + case "health_status": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + val := d.Val() + if len(val) == 3 && strings.HasSuffix(val, "xx") { + val = val[:1] + } + statusNum, err := strconv.Atoi(val) + if err != nil { + return d.Errf("bad status value '%s': %v", d.Val(), err) + } + h.HealthChecks.Active.ExpectStatus = statusNum - case "health_status": - if !d.NextArg() { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Active == nil { - h.HealthChecks.Active = new(ActiveHealthChecks) - } - val := d.Val() - if len(val) == 3 && strings.HasSuffix(val, "xx") { - val = val[:1] - } - statusNum, err := strconv.Atoi(val) - if err != nil { - return d.Errf("bad status value '%s': %v", d.Val(), err) - } - h.HealthChecks.Active.ExpectStatus = statusNum + case "health_body": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + h.HealthChecks.Active.ExpectBody = d.Val() - case "health_body": - if !d.NextArg() { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Active == nil { - h.HealthChecks.Active = new(ActiveHealthChecks) - } - h.HealthChecks.Active.ExpectBody = d.Val() + case "max_fails": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + maxFails, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("invalid maximum fail count '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.MaxFails = maxFails - case "max_fails": - if !d.NextArg() { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Passive == nil { - h.HealthChecks.Passive = new(PassiveHealthChecks) + case "fail_duration": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.FailDuration = caddy.Duration(dur) + + case "unhealthy_request_count": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + maxConns, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("invalid maximum connection count '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.UnhealthyRequestCount = maxConns + + case "unhealthy_status": + args := d.RemainingArgs() + if len(args) == 0 { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + for _, arg := range args { + if len(arg) == 3 && strings.HasSuffix(arg, "xx") { + arg = arg[:1] } - maxFails, err := strconv.Atoi(d.Val()) + statusNum, err := strconv.Atoi(arg) if err != nil { - return d.Errf("invalid maximum fail count '%s': %v", d.Val(), err) + return d.Errf("bad status value '%s': %v", d.Val(), err) } - h.HealthChecks.Passive.MaxFails = maxFails + h.HealthChecks.Passive.UnhealthyStatus = append(h.HealthChecks.Passive.UnhealthyStatus, statusNum) + } - case "fail_duration": - if !d.NextArg() { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Passive == nil { - h.HealthChecks.Passive = new(PassiveHealthChecks) - } + case "unhealthy_latency": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.UnhealthyLatency = caddy.Duration(dur) + + case "flush_interval": + if !d.NextArg() { + return d.ArgErr() + } + if fi, err := strconv.Atoi(d.Val()); err == nil { + h.FlushInterval = caddy.Duration(fi) + } else { dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("bad duration value '%s': %v", d.Val(), err) } - h.HealthChecks.Passive.FailDuration = caddy.Duration(dur) + h.FlushInterval = caddy.Duration(dur) + } - case "unhealthy_request_count": - if !d.NextArg() { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Passive == nil { - h.HealthChecks.Passive = new(PassiveHealthChecks) - } - maxConns, err := strconv.Atoi(d.Val()) + case "request_buffers", "response_buffers": + subdir := d.Val() + if !d.NextArg() { + return d.ArgErr() + } + val := d.Val() + var size int64 + if val == "unlimited" { + size = -1 + } else { + usize, err := humanize.ParseBytes(val) if err != nil { - return d.Errf("invalid maximum connection count '%s': %v", d.Val(), err) + return d.Errf("invalid byte size '%s': %v", val, err) } - h.HealthChecks.Passive.UnhealthyRequestCount = maxConns + size = int64(usize) + } + if d.NextArg() { + return d.ArgErr() + } + if subdir == "request_buffers" { + h.RequestBuffers = size + } else if subdir == "response_buffers" { + h.ResponseBuffers = size + } - case "unhealthy_status": - args := d.RemainingArgs() - if len(args) == 0 { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Passive == nil { - h.HealthChecks.Passive = new(PassiveHealthChecks) - } - for _, arg := range args { - if len(arg) == 3 && strings.HasSuffix(arg, "xx") { - arg = arg[:1] - } - statusNum, err := strconv.Atoi(arg) - if err != nil { - return d.Errf("bad status value '%s': %v", d.Val(), err) - } - h.HealthChecks.Passive.UnhealthyStatus = append(h.HealthChecks.Passive.UnhealthyStatus, statusNum) - } + // TODO: These three properties are deprecated; remove them sometime after v2.6.4 + case "buffer_requests": // TODO: deprecated + if d.NextArg() { + return d.ArgErr() + } + caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: buffer_requests: use request_buffers instead (with a maximum buffer size)") + h.DeprecatedBufferRequests = true + case "buffer_responses": // TODO: deprecated + if d.NextArg() { + return d.ArgErr() + } + caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: buffer_responses: use response_buffers instead (with a maximum buffer size)") + h.DeprecatedBufferResponses = true + case "max_buffer_size": // TODO: deprecated + if !d.NextArg() { + return d.ArgErr() + } + size, err := humanize.ParseBytes(d.Val()) + if err != nil { + return d.Errf("invalid byte size '%s': %v", d.Val(), err) + } + if d.NextArg() { + return d.ArgErr() + } + caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: max_buffer_size: use request_buffers and/or response_buffers instead (with maximum buffer sizes)") + h.DeprecatedMaxBufferSize = int64(size) - case "unhealthy_latency": - if !d.NextArg() { - return d.ArgErr() - } - if h.HealthChecks == nil { - h.HealthChecks = new(HealthChecks) - } - if h.HealthChecks.Passive == nil { - h.HealthChecks.Passive = new(PassiveHealthChecks) - } + case "stream_timeout": + if !d.NextArg() { + return d.ArgErr() + } + if fi, err := strconv.Atoi(d.Val()); err == nil { + h.StreamTimeout = caddy.Duration(fi) + } else { dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("bad duration value '%s': %v", d.Val(), err) } - h.HealthChecks.Passive.UnhealthyLatency = caddy.Duration(dur) - - case "flush_interval": - if !d.NextArg() { - return d.ArgErr() - } - if fi, err := strconv.Atoi(d.Val()); err == nil { - h.FlushInterval = caddy.Duration(fi) - } else { - dur, err := caddy.ParseDuration(d.Val()) - if err != nil { - return d.Errf("bad duration value '%s': %v", d.Val(), err) - } - h.FlushInterval = caddy.Duration(dur) - } + h.StreamTimeout = caddy.Duration(dur) + } - case "request_buffers", "response_buffers": - subdir := d.Val() - if !d.NextArg() { - return d.ArgErr() - } - size, err := humanize.ParseBytes(d.Val()) + case "stream_close_delay": + if !d.NextArg() { + return d.ArgErr() + } + if fi, err := strconv.Atoi(d.Val()); err == nil { + h.StreamCloseDelay = caddy.Duration(fi) + } else { + dur, err := caddy.ParseDuration(d.Val()) if err != nil { - return d.Errf("invalid byte size '%s': %v", d.Val(), err) - } - if d.NextArg() { - return d.ArgErr() - } - if subdir == "request_buffers" { - h.RequestBuffers = int64(size) - } else if subdir == "response_buffers" { - h.ResponseBuffers = int64(size) - + return d.Errf("bad duration value '%s': %v", d.Val(), err) } + h.StreamCloseDelay = caddy.Duration(dur) + } - // TODO: These three properties are deprecated; remove them sometime after v2.6.4 - case "buffer_requests": // TODO: deprecated - if d.NextArg() { - return d.ArgErr() + case "trusted_proxies": + for d.NextArg() { + if d.Val() == "private_ranges" { + h.TrustedProxies = append(h.TrustedProxies, caddyhttp.PrivateRangesCIDR()...) + continue } - caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: buffer_requests: use request_buffers instead (with a maximum buffer size)") - h.DeprecatedBufferRequests = true - case "buffer_responses": // TODO: deprecated - if d.NextArg() { - return d.ArgErr() - } - caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: buffer_responses: use response_buffers instead (with a maximum buffer size)") - h.DeprecatedBufferResponses = true - case "max_buffer_size": // TODO: deprecated - if !d.NextArg() { - return d.ArgErr() - } - size, err := humanize.ParseBytes(d.Val()) - if err != nil { - return d.Errf("invalid byte size '%s': %v", d.Val(), err) - } - if d.NextArg() { - return d.ArgErr() - } - caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: max_buffer_size: use request_buffers and/or response_buffers instead (with maximum buffer sizes)") - h.DeprecatedMaxBufferSize = int64(size) + h.TrustedProxies = append(h.TrustedProxies, d.Val()) + } - case "trusted_proxies": - for d.NextArg() { - if d.Val() == "private_ranges" { - h.TrustedProxies = append(h.TrustedProxies, caddyhttp.PrivateRangesCIDR()...) - continue - } - h.TrustedProxies = append(h.TrustedProxies, d.Val()) - } + case "header_up": + var err error - case "header_up": - var err error + if h.Headers == nil { + h.Headers = new(headers.Handler) + } + if h.Headers.Request == nil { + h.Headers.Request = new(headers.HeaderOps) + } + args := d.RemainingArgs() - if h.Headers == nil { - h.Headers = new(headers.Handler) + switch len(args) { + case 1: + err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], "", "") + case 2: + // some lint checks, I guess + if strings.EqualFold(args[0], "host") && (args[1] == "{hostport}" || args[1] == "{http.request.hostport}") { + caddy.Log().Named("caddyfile").Warn("Unnecessary header_up Host: the reverse proxy's default behavior is to pass headers to the upstream") } - if h.Headers.Request == nil { - h.Headers.Request = new(headers.HeaderOps) + if strings.EqualFold(args[0], "x-forwarded-for") && (args[1] == "{remote}" || args[1] == "{http.request.remote}" || args[1] == "{remote_host}" || args[1] == "{http.request.remote.host}") { + caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-For: the reverse proxy's default behavior is to pass headers to the upstream") } - args := d.RemainingArgs() - - switch len(args) { - case 1: - err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], "", "") - case 2: - // some lint checks, I guess - if strings.EqualFold(args[0], "host") && (args[1] == "{hostport}" || args[1] == "{http.request.hostport}") { - caddy.Log().Named("caddyfile").Warn("Unnecessary header_up Host: the reverse proxy's default behavior is to pass headers to the upstream") - } - if strings.EqualFold(args[0], "x-forwarded-for") && (args[1] == "{remote}" || args[1] == "{http.request.remote}" || args[1] == "{remote_host}" || args[1] == "{http.request.remote.host}") { - caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-For: the reverse proxy's default behavior is to pass headers to the upstream") - } - if strings.EqualFold(args[0], "x-forwarded-proto") && (args[1] == "{scheme}" || args[1] == "{http.request.scheme}") { - caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-Proto: the reverse proxy's default behavior is to pass headers to the upstream") - } - if strings.EqualFold(args[0], "x-forwarded-host") && (args[1] == "{host}" || args[1] == "{http.request.host}" || args[1] == "{hostport}" || args[1] == "{http.request.hostport}") { - caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-Host: the reverse proxy's default behavior is to pass headers to the upstream") - } - err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], args[1], "") - case 3: - err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], args[1], args[2]) - default: - return d.ArgErr() + if strings.EqualFold(args[0], "x-forwarded-proto") && (args[1] == "{scheme}" || args[1] == "{http.request.scheme}") { + caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-Proto: the reverse proxy's default behavior is to pass headers to the upstream") } - - if err != nil { - return d.Err(err.Error()) + if strings.EqualFold(args[0], "x-forwarded-host") && (args[1] == "{host}" || args[1] == "{http.request.host}" || args[1] == "{hostport}" || args[1] == "{http.request.hostport}") { + caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-Host: the reverse proxy's default behavior is to pass headers to the upstream") } + err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], args[1], "") + case 3: + err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], args[1], args[2]) + default: + return d.ArgErr() + } - case "header_down": - var err error + if err != nil { + return d.Err(err.Error()) + } - if h.Headers == nil { - h.Headers = new(headers.Handler) - } - if h.Headers.Response == nil { - h.Headers.Response = &headers.RespHeaderOps{ - HeaderOps: new(headers.HeaderOps), - } - } - args := d.RemainingArgs() - switch len(args) { - case 1: - err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], "", "") - case 2: - err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], args[1], "") - case 3: - err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], args[1], args[2]) - default: - return d.ArgErr() - } + case "header_down": + var err error - if err != nil { - return d.Err(err.Error()) - } - - case "method": - if !d.NextArg() { - return d.ArgErr() - } - if h.Rewrite == nil { - h.Rewrite = &rewrite.Rewrite{} - } - h.Rewrite.Method = d.Val() - if d.NextArg() { - return d.ArgErr() + if h.Headers == nil { + h.Headers = new(headers.Handler) + } + if h.Headers.Response == nil { + h.Headers.Response = &headers.RespHeaderOps{ + HeaderOps: new(headers.HeaderOps), } + } + args := d.RemainingArgs() + switch len(args) { + case 1: + err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], "", "") + case 2: + err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], args[1], "") + case 3: + err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], args[1], args[2]) + default: + return d.ArgErr() + } - case "rewrite": - if !d.NextArg() { - return d.ArgErr() - } - if h.Rewrite == nil { - h.Rewrite = &rewrite.Rewrite{} - } - h.Rewrite.URI = d.Val() - if d.NextArg() { - return d.ArgErr() - } + if err != nil { + return d.Err(err.Error()) + } - case "transport": - if !d.NextArg() { - return d.ArgErr() - } - if h.TransportRaw != nil { - return d.Err("transport already specified") - } - transportModuleName = d.Val() - modID := "http.reverse_proxy.transport." + transportModuleName - unm, err := caddyfile.UnmarshalModule(d, modID) - if err != nil { - return err - } - rt, ok := unm.(http.RoundTripper) - if !ok { - return d.Errf("module %s (%T) is not a RoundTripper", modID, unm) - } - transport = rt + case "method": + if !d.NextArg() { + return d.ArgErr() + } + if h.Rewrite == nil { + h.Rewrite = &rewrite.Rewrite{} + } + h.Rewrite.Method = d.Val() + if d.NextArg() { + return d.ArgErr() + } - case "handle_response": - // delegate the parsing of handle_response to the caller, - // since we need the httpcaddyfile.Helper to parse subroutes. - // See h.FinalizeUnmarshalCaddyfile - h.handleResponseSegments = append(h.handleResponseSegments, d.NewFromNextSegment()) + case "rewrite": + if !d.NextArg() { + return d.ArgErr() + } + if h.Rewrite == nil { + h.Rewrite = &rewrite.Rewrite{} + } + h.Rewrite.URI = d.Val() + if d.NextArg() { + return d.ArgErr() + } - case "replace_status": - args := d.RemainingArgs() - if len(args) != 1 && len(args) != 2 { - return d.Errf("must have one or two arguments: an optional response matcher, and a status code") - } + case "transport": + if !d.NextArg() { + return d.ArgErr() + } + if h.TransportRaw != nil { + return d.Err("transport already specified") + } + transportModuleName = d.Val() + modID := "http.reverse_proxy.transport." + transportModuleName + unm, err := caddyfile.UnmarshalModule(d, modID) + if err != nil { + return err + } + rt, ok := unm.(http.RoundTripper) + if !ok { + return d.Errf("module %s (%T) is not a RoundTripper", modID, unm) + } + transport = rt + + case "handle_response": + // delegate the parsing of handle_response to the caller, + // since we need the httpcaddyfile.Helper to parse subroutes. + // See h.FinalizeUnmarshalCaddyfile + h.handleResponseSegments = append(h.handleResponseSegments, d.NewFromNextSegment()) + + case "replace_status": + args := d.RemainingArgs() + if len(args) != 1 && len(args) != 2 { + return d.Errf("must have one or two arguments: an optional response matcher, and a status code") + } - responseHandler := caddyhttp.ResponseHandler{} + responseHandler := caddyhttp.ResponseHandler{} - if len(args) == 2 { - if !strings.HasPrefix(args[0], matcherPrefix) { - return d.Errf("must use a named response matcher, starting with '@'") - } - foundMatcher, ok := h.responseMatchers[args[0]] - if !ok { - return d.Errf("no named response matcher defined with name '%s'", args[0][1:]) - } - responseHandler.Match = &foundMatcher - responseHandler.StatusCode = caddyhttp.WeakString(args[1]) - } else if len(args) == 1 { - responseHandler.StatusCode = caddyhttp.WeakString(args[0]) + if len(args) == 2 { + if !strings.HasPrefix(args[0], matcherPrefix) { + return d.Errf("must use a named response matcher, starting with '@'") } - - // make sure there's no block, cause it doesn't make sense - if d.NextBlock(1) { - return d.Errf("cannot define routes for 'replace_status', use 'handle_response' instead.") + foundMatcher, ok := h.responseMatchers[args[0]] + if !ok { + return d.Errf("no named response matcher defined with name '%s'", args[0][1:]) } + responseHandler.Match = &foundMatcher + responseHandler.StatusCode = caddyhttp.WeakString(args[1]) + } else if len(args) == 1 { + responseHandler.StatusCode = caddyhttp.WeakString(args[0]) + } - h.HandleResponse = append( - h.HandleResponse, - responseHandler, - ) + // make sure there's no block, cause it doesn't make sense + if d.NextBlock(1) { + return d.Errf("cannot define routes for 'replace_status', use 'handle_response' instead.") + } - default: - return d.Errf("unrecognized subdirective %s", d.Val()) + h.HandleResponse = append( + h.HandleResponse, + responseHandler, + ) + + case "verbose_logs": + if h.VerboseLogs { + return d.Err("verbose_logs already specified") } + h.VerboseLogs = true + + default: + return d.Errf("unrecognized subdirective %s", d.Val()) } } @@ -918,6 +983,17 @@ func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } h.MaxResponseHeaderSize = int64(size) + case "proxy_protocol": + if !d.NextArg() { + return d.ArgErr() + } + switch proxyProtocol := d.Val(); proxyProtocol { + case "v1", "v2": + h.ProxyProtocol = proxyProtocol + default: + return d.Errf("invalid proxy protocol version '%s'", proxyProtocol) + } + case "dial_timeout": if !d.NextArg() { return d.ArgErr() @@ -1324,6 +1400,7 @@ func (u *SRVUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { // resolvers <resolvers...> // dial_timeout <timeout> // dial_fallback_delay <timeout> +// versions ipv4|ipv6 // } func (u *AUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { for d.Next() { @@ -1397,8 +1474,30 @@ func (u *AUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } u.FallbackDelay = caddy.Duration(dur) + case "versions": + args := d.RemainingArgs() + if len(args) == 0 { + return d.Errf("must specify at least one version") + } + + if u.Versions == nil { + u.Versions = &IPVersions{} + } + + trueBool := true + for _, arg := range args { + switch arg { + case "ipv4": + u.Versions.IPv4 = &trueBool + case "ipv6": + u.Versions.IPv6 = &trueBool + default: + return d.Errf("unsupported version: '%s'", arg) + } + } + default: - return d.Errf("unrecognized srv option '%s'", d.Val()) + return d.Errf("unrecognized a option '%s'", d.Val()) } } } diff --git a/modules/caddyhttp/reverseproxy/command.go b/modules/caddyhttp/reverseproxy/command.go index 44f4c22..11f935c 100644 --- a/modules/caddyhttp/reverseproxy/command.go +++ b/modules/caddyhttp/reverseproxy/command.go @@ -16,26 +16,28 @@ package reverseproxy import ( "encoding/json" - "flag" "fmt" "net/http" "strconv" + "strings" + + "github.com/spf13/cobra" + "go.uber.org/zap" + + caddycmd "github.com/caddyserver/caddy/v2/cmd" "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig" "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" - caddycmd "github.com/caddyserver/caddy/v2/cmd" "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/caddyserver/caddy/v2/modules/caddyhttp/headers" "github.com/caddyserver/caddy/v2/modules/caddytls" - "go.uber.org/zap" ) func init() { caddycmd.RegisterCommand(caddycmd.Command{ Name: "reverse-proxy", - Func: cmdReverseProxy, - Usage: "[--from <addr>] [--to <addr>] [--change-host-header]", + Usage: `[--from <addr>] [--to <addr>] [--change-host-header] [--insecure] [--internal-certs] [--disable-redirects] [--header-up "Field: value"] [--header-down "Field: value"] [--access-log] [--debug]`, Short: "A quick and production-ready reverse proxy", Long: ` A simple but production-ready reverse proxy. Useful for quick deployments, @@ -52,21 +54,33 @@ If the --from address has a host or IP, Caddy will attempt to serve the proxy over HTTPS with a certificate (unless overridden by the HTTP scheme or port). -If --change-host-header is set, the Host header on the request will be modified -from its original incoming value to the address of the upstream. (Otherwise, by -default, all incoming headers are passed through unmodified.) +If serving HTTPS: + --disable-redirects can be used to avoid binding to the HTTP port. + --internal-certs can be used to force issuance certs using the internal + CA instead of attempting to issue a public certificate. + +For proxying: + --header-up can be used to set a request header to send to the upstream. + --header-down can be used to set a response header to send back to the client. + --change-host-header sets the Host header on the request to the address + of the upstream, instead of defaulting to the incoming Host header. + This is a shortcut for --header-up "Host: {http.reverse_proxy.upstream.hostport}". + --insecure disables TLS verification with the upstream. WARNING: THIS + DISABLES SECURITY BY NOT VERIFYING THE UPSTREAM'S CERTIFICATE. `, - Flags: func() *flag.FlagSet { - fs := flag.NewFlagSet("reverse-proxy", flag.ExitOnError) - fs.String("from", "localhost", "Address on which to receive traffic") - fs.Var(&reverseProxyCmdTo, "to", "Upstream address(es) to which traffic should be sent") - fs.Bool("change-host-header", false, "Set upstream Host header to address of upstream") - fs.Bool("insecure", false, "Disable TLS verification (WARNING: DISABLES SECURITY BY NOT VERIFYING SSL CERTIFICATES!)") - fs.Bool("internal-certs", false, "Use internal CA for issuing certs") - fs.Bool("debug", false, "Enable verbose debug logs") - fs.Bool("disable-redirects", false, "Disable HTTP->HTTPS redirects") - return fs - }(), + CobraFunc: func(cmd *cobra.Command) { + cmd.Flags().StringP("from", "f", "localhost", "Address on which to receive traffic") + cmd.Flags().StringSliceP("to", "t", []string{}, "Upstream address(es) to which traffic should be sent") + cmd.Flags().BoolP("change-host-header", "c", false, "Set upstream Host header to address of upstream") + cmd.Flags().BoolP("insecure", "", false, "Disable TLS verification (WARNING: DISABLES SECURITY BY NOT VERIFYING TLS CERTIFICATES!)") + cmd.Flags().BoolP("disable-redirects", "r", false, "Disable HTTP->HTTPS redirects") + cmd.Flags().BoolP("internal-certs", "i", false, "Use internal CA for issuing certs") + cmd.Flags().StringSliceP("header-up", "H", []string{}, "Set a request header to send to the upstream (format: \"Field: value\")") + cmd.Flags().StringSliceP("header-down", "d", []string{}, "Set a response header to send back to the client (format: \"Field: value\")") + cmd.Flags().BoolP("access-log", "", false, "Enable the access log") + cmd.Flags().BoolP("debug", "v", false, "Enable verbose debug logs") + cmd.RunE = caddycmd.WrapCommandFuncForCobra(cmdReverseProxy) + }, }) } @@ -76,14 +90,19 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) { from := fs.String("from") changeHost := fs.Bool("change-host-header") insecure := fs.Bool("insecure") + disableRedir := fs.Bool("disable-redirects") internalCerts := fs.Bool("internal-certs") + accessLog := fs.Bool("access-log") debug := fs.Bool("debug") - disableRedir := fs.Bool("disable-redirects") httpPort := strconv.Itoa(caddyhttp.DefaultHTTPPort) httpsPort := strconv.Itoa(caddyhttp.DefaultHTTPSPort) - if len(reverseProxyCmdTo) == 0 { + to, err := fs.GetStringSlice("to") + if err != nil { + return caddy.ExitCodeFailedStartup, fmt.Errorf("invalid to flag: %v", err) + } + if len(to) == 0 { return caddy.ExitCodeFailedStartup, fmt.Errorf("--to is required") } @@ -112,17 +131,17 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) { // set up the upstream address; assume missing information from given parts // mixing schemes isn't supported, so use first defined (if available) - toAddresses := make([]string, len(reverseProxyCmdTo)) + toAddresses := make([]string, len(to)) var toScheme string - for i, toLoc := range reverseProxyCmdTo { - addr, scheme, err := parseUpstreamDialAddress(toLoc) + for i, toLoc := range to { + addr, err := parseUpstreamDialAddress(toLoc) if err != nil { return caddy.ExitCodeFailedStartup, fmt.Errorf("invalid upstream address %s: %v", toLoc, err) } - if scheme != "" && toScheme == "" { - toScheme = scheme + if addr.scheme != "" && toScheme == "" { + toScheme = addr.scheme } - toAddresses[i] = addr + toAddresses[i] = addr.dialAddr() } // proceed to build the handler and server @@ -136,9 +155,24 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) { upstreamPool := UpstreamPool{} for _, toAddr := range toAddresses { - upstreamPool = append(upstreamPool, &Upstream{ - Dial: toAddr, - }) + parsedAddr, err := caddy.ParseNetworkAddress(toAddr) + if err != nil { + return caddy.ExitCodeFailedStartup, fmt.Errorf("invalid upstream address %s: %v", toAddr, err) + } + + if parsedAddr.StartPort == 0 && parsedAddr.EndPort == 0 { + // unix networks don't have ports + upstreamPool = append(upstreamPool, &Upstream{ + Dial: toAddr, + }) + } else { + // expand a port range into multiple upstreams + for i := parsedAddr.StartPort; i <= parsedAddr.EndPort; i++ { + upstreamPool = append(upstreamPool, &Upstream{ + Dial: caddy.JoinNetworkAddress("", parsedAddr.Host, fmt.Sprint(i)), + }) + } + } } handler := Handler{ @@ -146,16 +180,64 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) { Upstreams: upstreamPool, } - if changeHost { + // set up header_up + headerUp, err := fs.GetStringSlice("header-up") + if err != nil { + return caddy.ExitCodeFailedStartup, fmt.Errorf("invalid header flag: %v", err) + } + if len(headerUp) > 0 { + reqHdr := make(http.Header) + for i, h := range headerUp { + key, val, found := strings.Cut(h, ":") + key, val = strings.TrimSpace(key), strings.TrimSpace(val) + if !found || key == "" || val == "" { + return caddy.ExitCodeFailedStartup, fmt.Errorf("header-up %d: invalid format \"%s\" (expecting \"Field: value\")", i, h) + } + reqHdr.Set(key, val) + } handler.Headers = &headers.Handler{ Request: &headers.HeaderOps{ - Set: http.Header{ - "Host": []string{"{http.reverse_proxy.upstream.hostport}"}, - }, + Set: reqHdr, + }, + } + } + + // set up header_down + headerDown, err := fs.GetStringSlice("header-down") + if err != nil { + return caddy.ExitCodeFailedStartup, fmt.Errorf("invalid header flag: %v", err) + } + if len(headerDown) > 0 { + respHdr := make(http.Header) + for i, h := range headerDown { + key, val, found := strings.Cut(h, ":") + key, val = strings.TrimSpace(key), strings.TrimSpace(val) + if !found || key == "" || val == "" { + return caddy.ExitCodeFailedStartup, fmt.Errorf("header-down %d: invalid format \"%s\" (expecting \"Field: value\")", i, h) + } + respHdr.Set(key, val) + } + if handler.Headers == nil { + handler.Headers = &headers.Handler{} + } + handler.Headers.Response = &headers.RespHeaderOps{ + HeaderOps: &headers.HeaderOps{ + Set: respHdr, }, } } + if changeHost { + if handler.Headers == nil { + handler.Headers = &headers.Handler{ + Request: &headers.HeaderOps{ + Set: http.Header{}, + }, + } + } + handler.Headers.Request.Set.Set("Host", "{http.reverse_proxy.upstream.hostport}") + } + route := caddyhttp.Route{ HandlersRaw: []json.RawMessage{ caddyconfig.JSONModuleObject(handler, "handler", "reverse_proxy", nil), @@ -173,6 +255,9 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) { Routes: caddyhttp.RouteList{route}, Listen: []string{":" + fromAddr.Port}, } + if accessLog { + server.Logs = &caddyhttp.ServerLogConfig{} + } if fromAddr.Scheme == "http" { server.AutoHTTPS = &caddyhttp.AutoHTTPSConfig{Disabled: true} @@ -191,8 +276,8 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) { tlsApp := caddytls.TLS{ Automation: &caddytls.AutomationConfig{ Policies: []*caddytls.AutomationPolicy{{ - Subjects: []string{fromAddr.Host}, - IssuersRaw: []json.RawMessage{json.RawMessage(`{"module":"internal"}`)}, + SubjectsRaw: []string{fromAddr.Host}, + IssuersRaw: []json.RawMessage{json.RawMessage(`{"module":"internal"}`)}, }}, }, } @@ -201,7 +286,8 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) { var false bool cfg := &caddy.Config{ - Admin: &caddy.AdminConfig{Disabled: true, + Admin: &caddy.AdminConfig{ + Disabled: true, Config: &caddy.ConfigSettings{ Persist: &false, }, @@ -212,7 +298,7 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) { if debug { cfg.Logging = &caddy.Logging{ Logs: map[string]*caddy.CustomLog{ - "default": {Level: zap.DebugLevel.CapitalString()}, + "default": {BaseLog: caddy.BaseLog{Level: zap.DebugLevel.CapitalString()}}, }, } } @@ -231,6 +317,3 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) { select {} } - -// reverseProxyCmdTo holds the parsed values from repeated use of the --to flag. -var reverseProxyCmdTo caddycmd.StringSlice diff --git a/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go b/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go index 799050e..a24a3ed 100644 --- a/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go @@ -217,25 +217,18 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error return nil, dispenser.ArgErr() } fcgiTransport.Root = dispenser.Val() - dispenser.Delete() - dispenser.Delete() + dispenser.DeleteN(2) case "split": extensions = dispenser.RemainingArgs() - dispenser.Delete() - for range extensions { - dispenser.Delete() - } + dispenser.DeleteN(len(extensions) + 1) if len(extensions) == 0 { return nil, dispenser.ArgErr() } case "env": args := dispenser.RemainingArgs() - dispenser.Delete() - for range args { - dispenser.Delete() - } + dispenser.DeleteN(len(args) + 1) if len(args) != 2 { return nil, dispenser.ArgErr() } @@ -246,10 +239,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error case "index": args := dispenser.RemainingArgs() - dispenser.Delete() - for range args { - dispenser.Delete() - } + dispenser.DeleteN(len(args) + 1) if len(args) != 1 { return nil, dispenser.ArgErr() } @@ -257,10 +247,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error case "try_files": args := dispenser.RemainingArgs() - dispenser.Delete() - for range args { - dispenser.Delete() - } + dispenser.DeleteN(len(args) + 1) if len(args) < 1 { return nil, dispenser.ArgErr() } @@ -268,10 +255,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error case "resolve_root_symlink": args := dispenser.RemainingArgs() - dispenser.Delete() - for range args { - dispenser.Delete() - } + dispenser.DeleteN(len(args) + 1) fcgiTransport.ResolveRootSymlink = true case "dial_timeout": @@ -283,8 +267,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error return nil, dispenser.Errf("bad timeout value %s: %v", dispenser.Val(), err) } fcgiTransport.DialTimeout = caddy.Duration(dur) - dispenser.Delete() - dispenser.Delete() + dispenser.DeleteN(2) case "read_timeout": if !dispenser.NextArg() { @@ -295,8 +278,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error return nil, dispenser.Errf("bad timeout value %s: %v", dispenser.Val(), err) } fcgiTransport.ReadTimeout = caddy.Duration(dur) - dispenser.Delete() - dispenser.Delete() + dispenser.DeleteN(2) case "write_timeout": if !dispenser.NextArg() { @@ -307,15 +289,11 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error return nil, dispenser.Errf("bad timeout value %s: %v", dispenser.Val(), err) } fcgiTransport.WriteTimeout = caddy.Duration(dur) - dispenser.Delete() - dispenser.Delete() + dispenser.DeleteN(2) case "capture_stderr": args := dispenser.RemainingArgs() - dispenser.Delete() - for range args { - dispenser.Delete() - } + dispenser.DeleteN(len(args) + 1) fcgiTransport.CaptureStderr = true } } @@ -395,6 +373,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error // the rest of the config is specified by the user // using the reverse_proxy directive syntax + dispenser.Next() // consume the directive name err = rpHandler.UnmarshalCaddyfile(dispenser) if err != nil { return nil, err diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client.go b/modules/caddyhttp/reverseproxy/fastcgi/client.go index ae36dd8..04513dd 100644 --- a/modules/caddyhttp/reverseproxy/fastcgi/client.go +++ b/modules/caddyhttp/reverseproxy/fastcgi/client.go @@ -251,7 +251,6 @@ func (c *client) Request(p map[string]string, req io.Reader) (resp *http.Respons // Get issues a GET request to the fcgi responder. func (c *client) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) { - p["REQUEST_METHOD"] = "GET" p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10) @@ -260,7 +259,6 @@ func (c *client) Get(p map[string]string, body io.Reader, l int64) (resp *http.R // Head issues a HEAD request to the fcgi responder. func (c *client) Head(p map[string]string) (resp *http.Response, err error) { - p["REQUEST_METHOD"] = "HEAD" p["CONTENT_LENGTH"] = "0" @@ -269,7 +267,6 @@ func (c *client) Head(p map[string]string) (resp *http.Response, err error) { // Options issues an OPTIONS request to the fcgi responder. func (c *client) Options(p map[string]string) (resp *http.Response, err error) { - p["REQUEST_METHOD"] = "OPTIONS" p["CONTENT_LENGTH"] = "0" diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go index ec194e7..31febdd 100644 --- a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go +++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go @@ -24,13 +24,13 @@ import ( "strings" "time" - "github.com/caddyserver/caddy/v2/modules/caddyhttp" - "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy" - "github.com/caddyserver/caddy/v2/modules/caddytls" "go.uber.org/zap" "go.uber.org/zap/zapcore" "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" + "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy" + "github.com/caddyserver/caddy/v2/modules/caddytls" ) var noopLogger = zap.NewNop() @@ -171,6 +171,7 @@ func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) { rwc: conn, reqID: 1, logger: logger, + stderr: t.CaptureStderr, } // read/write timeouts @@ -254,9 +255,7 @@ func (t Transport) buildEnv(r *http.Request) (envVars, error) { // if we didn't get a split result here. // See https://github.com/caddyserver/caddy/issues/3718 if pathInfo == "" { - if remainder, ok := repl.GetString("http.matchers.file.remainder"); ok { - pathInfo = remainder - } + pathInfo, _ = repl.GetString("http.matchers.file.remainder") } // SCRIPT_FILENAME is the absolute path of SCRIPT_NAME @@ -286,10 +285,7 @@ func (t Transport) buildEnv(r *http.Request) (envVars, error) { reqHost = r.Host } - authUser := "" - if val, ok := repl.Get("http.auth.user.id"); ok { - authUser = val.(string) - } + authUser, _ := repl.GetString("http.auth.user.id") // Some variables are unused but cleared explicitly to prevent // the parent environment from interfering. diff --git a/modules/caddyhttp/reverseproxy/forwardauth/caddyfile.go b/modules/caddyhttp/reverseproxy/forwardauth/caddyfile.go index cecc000..8350096 100644 --- a/modules/caddyhttp/reverseproxy/forwardauth/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/forwardauth/caddyfile.go @@ -129,8 +129,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error) return nil, dispenser.ArgErr() } rpHandler.Rewrite.URI = dispenser.Val() - dispenser.Delete() - dispenser.Delete() + dispenser.DeleteN(2) case "copy_headers": args := dispenser.RemainingArgs() @@ -140,13 +139,11 @@ func parseCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error) args = append(args, dispenser.Val()) } - dispenser.Delete() // directive name + // directive name + args + dispenser.DeleteN(len(args) + 1) if hadBlock { - dispenser.Delete() // opening brace - dispenser.Delete() // closing brace - } - for range args { - dispenser.Delete() + // opening & closing brace + dispenser.DeleteN(2) } for _, headerField := range args { @@ -219,6 +216,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error) // the rest of the config is specified by the user // using the reverse_proxy directive syntax + dispenser.Next() // consume the directive name err = rpHandler.UnmarshalCaddyfile(dispenser) if err != nil { return nil, err diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go index c27b24f..ad21ccb 100644 --- a/modules/caddyhttp/reverseproxy/healthchecks.go +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -24,12 +24,12 @@ import ( "regexp" "runtime/debug" "strconv" - "strings" "time" + "go.uber.org/zap" + "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/modules/caddyhttp" - "go.uber.org/zap" ) // HealthChecks configures active and passive health checks. @@ -106,6 +106,76 @@ type ActiveHealthChecks struct { logger *zap.Logger } +// Provision ensures that a is set up properly before use. +func (a *ActiveHealthChecks) Provision(ctx caddy.Context, h *Handler) error { + if !a.IsEnabled() { + return nil + } + + // Canonicalize the header keys ahead of time, since + // JSON unmarshaled headers may be incorrect + cleaned := http.Header{} + for key, hdrs := range a.Headers { + for _, val := range hdrs { + cleaned.Add(key, val) + } + } + a.Headers = cleaned + + h.HealthChecks.Active.logger = h.logger.Named("health_checker.active") + + timeout := time.Duration(a.Timeout) + if timeout == 0 { + timeout = 5 * time.Second + } + + if a.Path != "" { + a.logger.Warn("the 'path' option is deprecated, please use 'uri' instead!") + } + + // parse the URI string (supports path and query) + if a.URI != "" { + parsedURI, err := url.Parse(a.URI) + if err != nil { + return err + } + a.uri = parsedURI + } + + a.httpClient = &http.Client{ + Timeout: timeout, + Transport: h.Transport, + } + + for _, upstream := range h.Upstreams { + // if there's an alternative port for health-check provided in the config, + // then use it, otherwise use the port of upstream. + if a.Port != 0 { + upstream.activeHealthCheckPort = a.Port + } + } + + if a.Interval == 0 { + a.Interval = caddy.Duration(30 * time.Second) + } + + if a.ExpectBody != "" { + var err error + a.bodyRegexp, err = regexp.Compile(a.ExpectBody) + if err != nil { + return fmt.Errorf("expect_body: compiling regular expression: %v", err) + } + } + + return nil +} + +// IsEnabled checks if the active health checks have +// the minimum config necessary to be enabled. +func (a *ActiveHealthChecks) IsEnabled() bool { + return a.Path != "" || a.URI != "" || a.Port != 0 +} + // PassiveHealthChecks holds configuration related to passive // health checks (that is, health checks which occur during // the normal flow of request proxying). @@ -203,7 +273,7 @@ func (h *Handler) doActiveHealthCheckForAllHosts() { } addr.StartPort, addr.EndPort = hcp, hcp } - if upstream.LookupSRV == "" && addr.PortRangeSize() != 1 { + if addr.PortRangeSize() != 1 { h.HealthChecks.Active.logger.Error("multiple addresses (upstream must map to only one address)", zap.String("address", networkAddr), ) @@ -237,16 +307,35 @@ func (h *Handler) doActiveHealthCheckForAllHosts() { // the host's health status fails. func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstream *Upstream) error { // create the URL for the request that acts as a health check - scheme := "http" - if ht, ok := h.Transport.(TLSTransport); ok && ht.TLSEnabled() { - // this is kind of a hacky way to know if we should use HTTPS, but whatever - scheme = "https" - } u := &url.URL{ - Scheme: scheme, + Scheme: "http", Host: hostAddr, } + // split the host and port if possible, override the port if configured + host, port, err := net.SplitHostPort(hostAddr) + if err != nil { + host = hostAddr + } + if h.HealthChecks.Active.Port != 0 { + port := strconv.Itoa(h.HealthChecks.Active.Port) + u.Host = net.JoinHostPort(host, port) + } + + // this is kind of a hacky way to know if we should use HTTPS, but whatever + if tt, ok := h.Transport.(TLSTransport); ok && tt.TLSEnabled() { + u.Scheme = "https" + + // if the port is in the except list, flip back to HTTP + if ht, ok := h.Transport.(*HTTPTransport); ok { + for _, exceptPort := range ht.TLS.ExceptPorts { + if exceptPort == port { + u.Scheme = "http" + } + } + } + } + // if we have a provisioned uri, use that, otherwise use // the deprecated Path option if h.HealthChecks.Active.uri != nil { @@ -256,16 +345,6 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre u.Path = h.HealthChecks.Active.Path } - // adjust the port, if configured to be different - if h.HealthChecks.Active.Port != 0 { - portStr := strconv.Itoa(h.HealthChecks.Active.Port) - host, _, err := net.SplitHostPort(hostAddr) - if err != nil { - host = hostAddr - } - u.Host = net.JoinHostPort(host, portStr) - } - // attach dialing information to this request, as well as context values that // may be expected by handlers of this request ctx := h.ctx.Context @@ -279,11 +358,17 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre } ctx = context.WithValue(ctx, caddyhttp.OriginalRequestCtxKey, *req) req = req.WithContext(ctx) - for key, hdrs := range h.HealthChecks.Active.Headers { - if strings.ToLower(key) == "host" { - req.Host = h.HealthChecks.Active.Headers.Get(key) - } else { - req.Header[key] = hdrs + + // set headers, using a replacer with only globals (env vars, system info, etc.) + repl := caddy.NewReplacer() + for key, vals := range h.HealthChecks.Active.Headers { + key = repl.ReplaceAll(key, "") + if key == "Host" { + req.Host = repl.ReplaceAll(h.HealthChecks.Active.Headers.Get(key), "") + continue + } + for _, val := range vals { + req.Header.Add(key, repl.ReplaceKnown(val, "")) } } diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index a973ecb..83a39d8 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -17,8 +17,8 @@ package reverseproxy import ( "context" "fmt" - "net" "net/http" + "net/netip" "strconv" "sync/atomic" @@ -47,15 +47,6 @@ type Upstream struct { // backends is down. Also be aware of open proxy vulnerabilities. Dial string `json:"dial,omitempty"` - // DEPRECATED: Use the SRVUpstreams module instead - // (http.reverse_proxy.upstreams.srv). This field will be - // removed in a future version of Caddy. TODO: Remove this field. - // - // If DNS SRV records are used for service discovery with this - // upstream, specify the DNS name for which to look up SRV - // records here, instead of specifying a dial address. - LookupSRV string `json:"lookup_srv,omitempty"` - // The maximum number of simultaneous requests to allow to // this upstream. If set, overrides the global passive health // check UnhealthyRequestCount value. @@ -72,12 +63,10 @@ type Upstream struct { unhealthy int32 // accessed atomically; status from active health checker } -func (u Upstream) String() string { - if u.LookupSRV != "" { - return u.LookupSRV - } - return u.Dial -} +// (pointer receiver necessary to avoid a race condition, since +// copying the Upstream reads the 'unhealthy' field which is +// accessed atomically) +func (u *Upstream) String() string { return u.Dial } // Available returns true if the remote host // is available to receive requests. This is @@ -109,35 +98,21 @@ func (u *Upstream) Full() bool { } // fillDialInfo returns a filled DialInfo for upstream u, using the request -// context. If the upstream has a SRV lookup configured, that is done and a -// returned address is chosen; otherwise, the upstream's regular dial address -// field is used. Note that the returned value is not a pointer. +// context. Note that the returned value is not a pointer. func (u *Upstream) fillDialInfo(r *http.Request) (DialInfo, error) { repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) var addr caddy.NetworkAddress - if u.LookupSRV != "" { - // perform DNS lookup for SRV records and choose one - TODO: deprecated - srvName := repl.ReplaceAll(u.LookupSRV, "") - _, records, err := net.DefaultResolver.LookupSRV(r.Context(), "", "", srvName) - if err != nil { - return DialInfo{}, err - } - addr.Network = "tcp" - addr.Host = records[0].Target - addr.StartPort, addr.EndPort = uint(records[0].Port), uint(records[0].Port) - } else { - // use provided dial address - var err error - dial := repl.ReplaceAll(u.Dial, "") - addr, err = caddy.ParseNetworkAddress(dial) - if err != nil { - return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", u.Dial, dial, err) - } - if numPorts := addr.PortRangeSize(); numPorts != 1 { - return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d", - u.Dial, dial, numPorts) - } + // use provided dial address + var err error + dial := repl.ReplaceAll(u.Dial, "") + addr, err = caddy.ParseNetworkAddress(dial) + if err != nil { + return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", u.Dial, dial, err) + } + if numPorts := addr.PortRangeSize(); numPorts != 1 { + return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d", + u.Dial, dial, numPorts) } return DialInfo{ @@ -259,3 +234,13 @@ var hosts = caddy.NewUsagePool() // dialInfoVarKey is the key used for the variable that holds // the dial info for the upstream connection. const dialInfoVarKey = "reverse_proxy.dial_info" + +// proxyProtocolInfoVarKey is the key used for the variable that holds +// the proxy protocol info for the upstream connection. +const proxyProtocolInfoVarKey = "reverse_proxy.proxy_protocol_info" + +// ProxyProtocolInfo contains information needed to write proxy protocol to a +// connection to an upstream host. +type ProxyProtocolInfo struct { + AddrPort netip.AddrPort +} diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index ec5d2f2..9290f7e 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -28,10 +28,13 @@ import ( "strings" "time" - "github.com/caddyserver/caddy/v2" - "github.com/caddyserver/caddy/v2/modules/caddytls" + "github.com/mastercactapus/proxyprotocol" "go.uber.org/zap" "golang.org/x/net/http2" + + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" + "github.com/caddyserver/caddy/v2/modules/caddytls" ) func init() { @@ -64,6 +67,10 @@ type HTTPTransport struct { // Maximum number of connections per host. Default: 0 (no limit) MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` + // If non-empty, which PROXY protocol version to send when + // connecting to an upstream. Default: off. + ProxyProtocol string `json:"proxy_protocol,omitempty"` + // How long to wait before timing out trying to connect to // an upstream. Default: `3s`. DialTimeout caddy.Duration `json:"dial_timeout,omitempty"` @@ -172,12 +179,19 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e } } - // Set up the dialer to pull the correct information from the context dialContext := func(ctx context.Context, network, address string) (net.Conn, error) { - // the proper dialing information should be embedded into the request's context + // For unix socket upstreams, we need to recover the dial info from + // the request's context, because the Host on the request's URL + // will have been modified by directing the request, overwriting + // the unix socket filename. + // Also, we need to avoid overwriting the address at this point + // when not necessary, because http.ProxyFromEnvironment may have + // modified the address according to the user's env proxy config. if dialInfo, ok := GetDialInfo(ctx); ok { - network = dialInfo.Network - address = dialInfo.Address + if strings.HasPrefix(dialInfo.Network, "unix") { + network = dialInfo.Network + address = dialInfo.Address + } } conn, err := dialer.DialContext(ctx, network, address) @@ -188,8 +202,59 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e return nil, DialError{err} } - // if read/write timeouts are configured and this is a TCP connection, enforce the timeouts - // by wrapping the connection with our own type + if h.ProxyProtocol != "" { + proxyProtocolInfo, ok := caddyhttp.GetVar(ctx, proxyProtocolInfoVarKey).(ProxyProtocolInfo) + if !ok { + return nil, fmt.Errorf("failed to get proxy protocol info from context") + } + + // The src and dst have to be of the some address family. As we don't know the original + // dst address (it's kind of impossible to know) and this address is generelly of very + // little interest, we just set it to all zeros. + var destIP net.IP + switch { + case proxyProtocolInfo.AddrPort.Addr().Is4(): + destIP = net.IPv4zero + case proxyProtocolInfo.AddrPort.Addr().Is6(): + destIP = net.IPv6zero + default: + return nil, fmt.Errorf("unexpected remote addr type in proxy protocol info") + } + + // TODO: We should probably migrate away from net.IP to use netip.Addr, + // but due to the upstream dependency, we can't do that yet. + switch h.ProxyProtocol { + case "v1": + header := proxyprotocol.HeaderV1{ + SrcIP: net.IP(proxyProtocolInfo.AddrPort.Addr().AsSlice()), + SrcPort: int(proxyProtocolInfo.AddrPort.Port()), + DestIP: destIP, + DestPort: 0, + } + caddyCtx.Logger().Debug("sending proxy protocol header v1", zap.Any("header", header)) + _, err = header.WriteTo(conn) + case "v2": + header := proxyprotocol.HeaderV2{ + Command: proxyprotocol.CmdProxy, + Src: &net.TCPAddr{IP: net.IP(proxyProtocolInfo.AddrPort.Addr().AsSlice()), Port: int(proxyProtocolInfo.AddrPort.Port())}, + Dest: &net.TCPAddr{IP: destIP, Port: 0}, + } + caddyCtx.Logger().Debug("sending proxy protocol header v2", zap.Any("header", header)) + _, err = header.WriteTo(conn) + default: + return nil, fmt.Errorf("unexpected proxy protocol version") + } + + if err != nil { + // identify this error as one that occurred during + // dialing, which can be important when trying to + // decide whether to retry a request + return nil, DialError{err} + } + } + + // if read/write timeouts are configured and this is a TCP connection, + // enforce the timeouts by wrapping the connection with our own type if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) { conn = &tcpRWTimeoutConn{ TCPConn: tcpConn, @@ -203,6 +268,7 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e } rt := &http.Transport{ + Proxy: http.ProxyFromEnvironment, DialContext: dialContext, MaxConnsPerHost: h.MaxConnsPerHost, ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout), @@ -231,6 +297,14 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout) } + // The proxy protocol header can only be sent once right after opening the connection. + // So single connection must not be used for multiple requests, which can potentially + // come from different clients. + if !rt.DisableKeepAlives && h.ProxyProtocol != "" { + caddyCtx.Logger().Warn("disabling keepalives, they are incompatible with using PROXY protocol") + rt.DisableKeepAlives = true + } + if h.Compression != nil { rt.DisableCompression = !*h.Compression } @@ -452,10 +526,11 @@ func (t TLSConfig) MakeTLSClientConfig(ctx caddy.Context) (*tls.Config, error) { return nil, fmt.Errorf("managing client certificate: %v", err) } cfg.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { - certs := tlsApp.AllMatchingCertificates(t.ClientCertificateAutomate) + certs := caddytls.AllMatchingCertificates(t.ClientCertificateAutomate) var err error for _, cert := range certs { - err = cri.SupportsCertificate(&cert.Certificate) + certCertificate := cert.Certificate // avoid taking address of iteration variable (gosec warning) + err = cri.SupportsCertificate(&certCertificate) if err == nil { return &cert.Certificate, nil } diff --git a/modules/caddyhttp/reverseproxy/metrics.go b/modules/caddyhttp/reverseproxy/metrics.go index 4272bc4..d3c8ee0 100644 --- a/modules/caddyhttp/reverseproxy/metrics.go +++ b/modules/caddyhttp/reverseproxy/metrics.go @@ -39,6 +39,8 @@ func newMetricsUpstreamsHealthyUpdater(handler *Handler) *metricsUpstreamsHealth initReverseProxyMetrics(handler) }) + reverseProxyMetrics.upstreamsHealthy.Reset() + return &metricsUpstreamsHealthyUpdater{handler} } diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 1449785..08be40d 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -27,30 +27,23 @@ import ( "net/netip" "net/textproto" "net/url" - "regexp" - "runtime" "strconv" "strings" "sync" "time" + "go.uber.org/zap" + "golang.org/x/net/http/httpguts" + "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" "github.com/caddyserver/caddy/v2/modules/caddyevents" "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/caddyserver/caddy/v2/modules/caddyhttp/headers" "github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite" - "go.uber.org/zap" - "golang.org/x/net/http/httpguts" ) -var supports1xx bool - func init() { - // Caddy requires at least Go 1.18, but Early Hints requires Go 1.19; thus we can simply check for 1.18 in version string - // TODO: remove this once our minimum Go version is 1.19 - supports1xx = !strings.Contains(runtime.Version(), "go1.18") - caddy.RegisterModule(Handler{}) } @@ -158,6 +151,19 @@ type Handler struct { // could be useful if the backend has tighter memory constraints. ResponseBuffers int64 `json:"response_buffers,omitempty"` + // If nonzero, streaming requests such as WebSockets will be + // forcibly closed at the end of the timeout. Default: no timeout. + StreamTimeout caddy.Duration `json:"stream_timeout,omitempty"` + + // If nonzero, streaming requests such as WebSockets will not be + // closed when the proxy config is unloaded, and instead the stream + // will remain open until the delay is complete. In other words, + // enabling this prevents streams from closing when Caddy's config + // is reloaded. Enabling this may be a good idea to avoid a thundering + // herd of reconnecting clients which had their connections closed + // by the previous config closing. Default: no delay. + StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"` + // If configured, rewrites the copy of the upstream request. // Allows changing the request method and URI (path and query). // Since the rewrite is applied to the copy, it does not persist @@ -185,6 +191,13 @@ type Handler struct { // - `{http.reverse_proxy.header.*}` The headers from the response HandleResponse []caddyhttp.ResponseHandler `json:"handle_response,omitempty"` + // If set, the proxy will write very detailed logs about its + // inner workings. Enable this only when debugging, as it + // will produce a lot of output. + // + // EXPERIMENTAL: This feature is subject to change or removal. + VerboseLogs bool `json:"verbose_logs,omitempty"` + Transport http.RoundTripper `json:"-"` CB CircuitBreaker `json:"-"` DynamicUpstreams UpstreamSource `json:"-"` @@ -199,8 +212,9 @@ type Handler struct { handleResponseSegments []*caddyfile.Dispenser // Stores upgraded requests (hijacked connections) for proper cleanup - connections map[io.ReadWriteCloser]openConnection - connectionsMu *sync.Mutex + connections map[io.ReadWriteCloser]openConnection + connectionsCloseTimer *time.Timer + connectionsMu *sync.Mutex ctx caddy.Context logger *zap.Logger @@ -243,20 +257,6 @@ func (h *Handler) Provision(ctx caddy.Context) error { h.logger.Warn("UNLIMITED BUFFERING: buffering is enabled without any cap on buffer size, which can result in OOM crashes") } - // verify SRV compatibility - TODO: LookupSRV deprecated; will be removed - for i, v := range h.Upstreams { - if v.LookupSRV == "" { - continue - } - h.logger.Warn("DEPRECATED: lookup_srv: will be removed in a near-future version of Caddy; use the http.reverse_proxy.upstreams.srv module instead") - if h.HealthChecks != nil && h.HealthChecks.Active != nil { - return fmt.Errorf(`upstream: lookup_srv is incompatible with active health checks: %d: {"dial": %q, "lookup_srv": %q}`, i, v.Dial, v.LookupSRV) - } - if v.Dial != "" { - return fmt.Errorf(`upstream: specifying dial address is incompatible with lookup_srv: %d: {"dial": %q, "lookup_srv": %q}`, i, v.Dial, v.LookupSRV) - } - } - // start by loading modules if h.TransportRaw != nil { mod, err := ctx.LoadModule(h, "TransportRaw") @@ -363,62 +363,22 @@ func (h *Handler) Provision(ctx caddy.Context) error { if h.HealthChecks != nil { // set defaults on passive health checks, if necessary if h.HealthChecks.Passive != nil { - if h.HealthChecks.Passive.FailDuration > 0 && h.HealthChecks.Passive.MaxFails == 0 { + h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive") + if h.HealthChecks.Passive.MaxFails == 0 { h.HealthChecks.Passive.MaxFails = 1 } } // if active health checks are enabled, configure them and start a worker - if h.HealthChecks.Active != nil && (h.HealthChecks.Active.Path != "" || - h.HealthChecks.Active.URI != "" || - h.HealthChecks.Active.Port != 0) { - - h.HealthChecks.Active.logger = h.logger.Named("health_checker.active") - - timeout := time.Duration(h.HealthChecks.Active.Timeout) - if timeout == 0 { - timeout = 5 * time.Second - } - - if h.HealthChecks.Active.Path != "" { - h.HealthChecks.Active.logger.Warn("the 'path' option is deprecated, please use 'uri' instead!") - } - - // parse the URI string (supports path and query) - if h.HealthChecks.Active.URI != "" { - parsedURI, err := url.Parse(h.HealthChecks.Active.URI) - if err != nil { - return err - } - h.HealthChecks.Active.uri = parsedURI - } - - h.HealthChecks.Active.httpClient = &http.Client{ - Timeout: timeout, - Transport: h.Transport, - } - - for _, upstream := range h.Upstreams { - // if there's an alternative port for health-check provided in the config, - // then use it, otherwise use the port of upstream. - if h.HealthChecks.Active.Port != 0 { - upstream.activeHealthCheckPort = h.HealthChecks.Active.Port - } + if h.HealthChecks.Active != nil { + err := h.HealthChecks.Active.Provision(ctx, h) + if err != nil { + return err } - if h.HealthChecks.Active.Interval == 0 { - h.HealthChecks.Active.Interval = caddy.Duration(30 * time.Second) + if h.HealthChecks.Active.IsEnabled() { + go h.activeHealthChecker() } - - if h.HealthChecks.Active.ExpectBody != "" { - var err error - h.HealthChecks.Active.bodyRegexp, err = regexp.Compile(h.HealthChecks.Active.ExpectBody) - if err != nil { - return fmt.Errorf("expect_body: compiling regular expression: %v", err) - } - } - - go h.activeHealthChecker() } } @@ -438,25 +398,7 @@ func (h *Handler) Provision(ctx caddy.Context) error { // Cleanup cleans up the resources made by h. func (h *Handler) Cleanup() error { - // close hijacked connections (both to client and backend) - var err error - h.connectionsMu.Lock() - for _, oc := range h.connections { - if oc.gracefulClose != nil { - // this is potentially blocking while we have the lock on the connections - // map, but that should be OK since the server has in theory shut down - // and we are no longer using the connections map - gracefulErr := oc.gracefulClose() - if gracefulErr != nil && err == nil { - err = gracefulErr - } - } - closeErr := oc.conn.Close() - if closeErr != nil && err == nil { - err = closeErr - } - } - h.connectionsMu.Unlock() + err := h.cleanupConnections() // remove hosts from our config from the pool for _, upstream := range h.Upstreams { @@ -517,7 +459,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht // It returns true when the loop is done and should break; false otherwise. The error value returned should // be assigned to the proxyErr value for the next iteration of the loop (or the error handled after break). func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w http.ResponseWriter, proxyErr error, start time.Time, retries int, - repl *caddy.Replacer, reqHeader http.Header, reqHost string, next caddyhttp.Handler) (bool, error) { + repl *caddy.Replacer, reqHeader http.Header, reqHost string, next caddyhttp.Handler, +) (bool, error) { // get the updated list of upstreams upstreams := h.Upstreams if h.DynamicUpstreams != nil { @@ -544,7 +487,7 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h upstream := h.LoadBalancing.SelectionPolicy.Select(upstreams, r, w) if upstream == nil { if proxyErr == nil { - proxyErr = caddyhttp.Error(http.StatusServiceUnavailable, fmt.Errorf("no upstreams available")) + proxyErr = caddyhttp.Error(http.StatusServiceUnavailable, noUpstreamsAvailable) } if !h.LoadBalancing.tryAgain(h.ctx, start, retries, proxyErr, r) { return true, proxyErr @@ -646,7 +589,8 @@ func (h Handler) prepareRequest(req *http.Request, repl *caddy.Replacer) (*http. // feature if absolutely required, if read timeouts are // set, and if body size is limited if h.RequestBuffers != 0 && req.Body != nil { - req.Body, _ = h.bufferedBody(req.Body, h.RequestBuffers) + req.Body, req.ContentLength = h.bufferedBody(req.Body, h.RequestBuffers) + req.Header.Set("Content-Length", strconv.FormatInt(req.ContentLength, 10)) } if req.ContentLength == 0 { @@ -687,8 +631,24 @@ func (h Handler) prepareRequest(req *http.Request, repl *caddy.Replacer) (*http. req.Header.Set("Upgrade", reqUpType) } + // Set up the PROXY protocol info + address := caddyhttp.GetVar(req.Context(), caddyhttp.ClientIPVarKey).(string) + addrPort, err := netip.ParseAddrPort(address) + if err != nil { + // OK; probably didn't have a port + addr, err := netip.ParseAddr(address) + if err != nil { + // Doesn't seem like a valid ip address at all + } else { + // Ok, only the port was missing + addrPort = netip.AddrPortFrom(addr, 0) + } + } + proxyProtocolInfo := ProxyProtocolInfo{AddrPort: addrPort} + caddyhttp.SetVar(req.Context(), proxyProtocolInfoVarKey, proxyProtocolInfo) + // Add the supported X-Forwarded-* headers - err := h.addForwardedHeaders(req) + err = h.addForwardedHeaders(req) if err != nil { return nil, err } @@ -795,25 +755,23 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe server := req.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server) shouldLogCredentials := server.Logs != nil && server.Logs.ShouldLogCredentials - if supports1xx { - // Forward 1xx status codes, backported from https://github.com/golang/go/pull/53164 - trace := &httptrace.ClientTrace{ - Got1xxResponse: func(code int, header textproto.MIMEHeader) error { - h := rw.Header() - copyHeader(h, http.Header(header)) - rw.WriteHeader(code) - - // Clear headers coming from the backend - // (it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses) - for k := range header { - delete(h, k) - } + // Forward 1xx status codes, backported from https://github.com/golang/go/pull/53164 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + h := rw.Header() + copyHeader(h, http.Header(header)) + rw.WriteHeader(code) + + // Clear headers coming from the backend + // (it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses) + for k := range header { + delete(h, k) + } - return nil - }, - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + return nil + }, } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) // if FlushInterval is explicitly configured to -1 (i.e. flush continuously to achieve // low-latency streaming), don't let the transport cancel the request if the client @@ -821,7 +779,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe // regardless, and we should expect client disconnection in low-latency streaming // scenarios (see issue #4922) if h.FlushInterval == -1 { - req = req.WithContext(ignoreClientGoneContext{req.Context(), h.ctx.Done()}) + req = req.WithContext(ignoreClientGoneContext{req.Context()}) } // do the round-trip; emit debug log with values we know are @@ -897,12 +855,6 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe break } - // otherwise, if there are any routes configured, execute those as the - // actual response instead of what we got from the proxy backend - if len(rh.Routes) == 0 { - continue - } - // set up the replacer so that parts of the original response can be // used for routing decisions for field, value := range res.Header { @@ -911,7 +863,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe repl.Set("http.reverse_proxy.status_code", res.StatusCode) repl.Set("http.reverse_proxy.status_text", res.Status) - h.logger.Debug("handling response", zap.Int("handler", i)) + logger.Debug("handling response", zap.Int("handler", i)) // we make some data available via request context to child routes // so that they may inherit some options and functions from the @@ -956,7 +908,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe } // finalizeResponse prepares and copies the response. -func (h Handler) finalizeResponse( +func (h *Handler) finalizeResponse( rw http.ResponseWriter, req *http.Request, res *http.Response, @@ -998,15 +950,21 @@ func (h Handler) finalizeResponse( } rw.WriteHeader(res.StatusCode) + if h.VerboseLogs { + logger.Debug("wrote header") + } - err := h.copyResponse(rw, res.Body, h.flushInterval(req, res)) - res.Body.Close() // close now, instead of defer, to populate res.Trailer + err := h.copyResponse(rw, res.Body, h.flushInterval(req, res), logger) + errClose := res.Body.Close() // close now, instead of defer, to populate res.Trailer + if h.VerboseLogs || errClose != nil { + logger.Debug("closed response body from upstream", zap.Error(errClose)) + } if err != nil { // we're streaming the response and we've already written headers, so // there's nothing an error handler can do to recover at this point; // the standard lib's proxy panics at this point, but we'll just log // the error and abort the stream here - h.logger.Error("aborting with incomplete response", zap.Error(err)) + logger.Error("aborting with incomplete response", zap.Error(err)) return nil } @@ -1014,9 +972,8 @@ func (h Handler) finalizeResponse( // Force chunking if we saw a response trailer. // This prevents net/http from calculating the length for short // bodies and adding a Content-Length. - if fl, ok := rw.(http.Flusher); ok { - fl.Flush() - } + //nolint:bodyclose + http.NewResponseController(rw).Flush() } // total duration spent proxying, including writing response body @@ -1035,6 +992,10 @@ func (h Handler) finalizeResponse( } } + if h.VerboseLogs { + logger.Debug("response finalized") + } + return nil } @@ -1066,17 +1027,23 @@ func (lb LoadBalancing) tryAgain(ctx caddy.Context, start time.Time, retries int // should be safe to retry, since without a connection, no // HTTP request can be transmitted; but if the error is not // specifically a dialer error, we need to be careful - if _, ok := proxyErr.(DialError); proxyErr != nil && !ok { + if proxyErr != nil { + _, isDialError := proxyErr.(DialError) + herr, isHandlerError := proxyErr.(caddyhttp.HandlerError) + // if the error occurred after a connection was established, // we have to assume the upstream received the request, and // retries need to be carefully decided, because some requests // are not idempotent - if lb.RetryMatch == nil && req.Method != "GET" { - // by default, don't retry requests if they aren't GET - return false - } - if !lb.RetryMatch.AnyMatch(req) { - return false + if !isDialError && !(isHandlerError && errors.Is(herr, noUpstreamsAvailable)) { + if lb.RetryMatch == nil && req.Method != "GET" { + // by default, don't retry requests if they aren't GET + return false + } + + if !lb.RetryMatch.AnyMatch(req) { + return false + } } } @@ -1128,12 +1095,11 @@ func (h Handler) provisionUpstream(upstream *Upstream) { // without MaxRequests), copy the value into this upstream, since the // value in the upstream (MaxRequests) is what is used during // availability checks - if h.HealthChecks != nil && h.HealthChecks.Passive != nil { - h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive") - if h.HealthChecks.Passive.UnhealthyRequestCount > 0 && - upstream.MaxRequests == 0 { - upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount - } + if h.HealthChecks != nil && + h.HealthChecks.Passive != nil && + h.HealthChecks.Passive.UnhealthyRequestCount > 0 && + upstream.MaxRequests == 0 { + upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount } // upstreams need independent access to the passive @@ -1450,21 +1416,36 @@ type handleResponseContext struct { // ignoreClientGoneContext is a special context.Context type // intended for use when doing a RoundTrip where you don't // want a client disconnection to cancel the request during -// the roundtrip. Set its done field to a Done() channel -// of a context that doesn't get canceled when the client -// disconnects, such as caddy.Context.Done() instead. +// the roundtrip. +// This context clears cancellation, error, and deadline methods, +// but still allows values to pass through from its embedded +// context. +// +// TODO: This can be replaced with context.WithoutCancel once +// the minimum required version of Go is 1.21. type ignoreClientGoneContext struct { context.Context - done <-chan struct{} } -func (c ignoreClientGoneContext) Done() <-chan struct{} { return c.done } +func (c ignoreClientGoneContext) Deadline() (deadline time.Time, ok bool) { + return +} + +func (c ignoreClientGoneContext) Done() <-chan struct{} { + return nil +} + +func (c ignoreClientGoneContext) Err() error { + return nil +} // proxyHandleResponseContextCtxKey is the context key for the active proxy handler // so that handle_response routes can inherit some config options // from the proxy handler. const proxyHandleResponseContextCtxKey caddy.CtxKey = "reverse_proxy_handle_response_context" +var noUpstreamsAvailable = fmt.Errorf("no upstreams available") + // Interface guards var ( _ caddy.Provisioner = (*Handler)(nil) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go index 0b7f50c..acb069a 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -18,17 +18,20 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/hex" + "encoding/json" "fmt" "hash/fnv" weakrand "math/rand" "net" "net/http" "strconv" + "strings" "sync/atomic" - "time" "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" ) func init() { @@ -36,13 +39,14 @@ func init() { caddy.RegisterModule(RandomChoiceSelection{}) caddy.RegisterModule(LeastConnSelection{}) caddy.RegisterModule(RoundRobinSelection{}) + caddy.RegisterModule(WeightedRoundRobinSelection{}) caddy.RegisterModule(FirstSelection{}) caddy.RegisterModule(IPHashSelection{}) + caddy.RegisterModule(ClientIPHashSelection{}) caddy.RegisterModule(URIHashSelection{}) + caddy.RegisterModule(QueryHashSelection{}) caddy.RegisterModule(HeaderHashSelection{}) caddy.RegisterModule(CookieHashSelection{}) - - weakrand.Seed(time.Now().UTC().UnixNano()) } // RandomSelection is a policy that selects @@ -72,6 +76,90 @@ func (r *RandomSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { return nil } +// WeightedRoundRobinSelection is a policy that selects +// a host based on weighted round-robin ordering. +type WeightedRoundRobinSelection struct { + // The weight of each upstream in order, + // corresponding with the list of upstreams configured. + Weights []int `json:"weights,omitempty"` + index uint32 + totalWeight int +} + +// CaddyModule returns the Caddy module information. +func (WeightedRoundRobinSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "http.reverse_proxy.selection_policies.weighted_round_robin", + New: func() caddy.Module { + return new(WeightedRoundRobinSelection) + }, + } +} + +// UnmarshalCaddyfile sets up the module from Caddyfile tokens. +func (r *WeightedRoundRobinSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + for d.Next() { + args := d.RemainingArgs() + if len(args) == 0 { + return d.ArgErr() + } + + for _, weight := range args { + weightInt, err := strconv.Atoi(weight) + if err != nil { + return d.Errf("invalid weight value '%s': %v", weight, err) + } + if weightInt < 1 { + return d.Errf("invalid weight value '%s': weight should be non-zero and positive", weight) + } + r.Weights = append(r.Weights, weightInt) + } + } + return nil +} + +// Provision sets up r. +func (r *WeightedRoundRobinSelection) Provision(ctx caddy.Context) error { + for _, weight := range r.Weights { + r.totalWeight += weight + } + return nil +} + +// Select returns an available host, if any. +func (r *WeightedRoundRobinSelection) Select(pool UpstreamPool, _ *http.Request, _ http.ResponseWriter) *Upstream { + if len(pool) == 0 { + return nil + } + if len(r.Weights) < 2 { + return pool[0] + } + var index, totalWeight int + currentWeight := int(atomic.AddUint32(&r.index, 1)) % r.totalWeight + for i, weight := range r.Weights { + totalWeight += weight + if currentWeight < totalWeight { + index = i + break + } + } + + upstreams := make([]*Upstream, 0, len(r.Weights)) + for _, upstream := range pool { + if !upstream.Available() { + continue + } + upstreams = append(upstreams, upstream) + if len(upstreams) == cap(upstreams) { + break + } + } + if len(upstreams) == 0 { + return nil + } + return upstreams[index%len(upstreams)] +} + // RandomChoiceSelection is a policy that selects // two or more available hosts at random, then // chooses the one with the least load. @@ -181,7 +269,7 @@ func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request, _ http.Resp // sample: https://en.wikipedia.org/wiki/Reservoir_sampling if numReqs == leastReqs { count++ - if (weakrand.Int() % count) == 0 { //nolint:gosec + if count == 1 || (weakrand.Int()%count) == 0 { //nolint:gosec bestHost = host } } @@ -303,6 +391,39 @@ func (r *IPHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { return nil } +// ClientIPHashSelection is a policy that selects a host +// based on hashing the client IP of the request, as determined +// by the HTTP app's trusted proxies settings. +type ClientIPHashSelection struct{} + +// CaddyModule returns the Caddy module information. +func (ClientIPHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "http.reverse_proxy.selection_policies.client_ip_hash", + New: func() caddy.Module { return new(ClientIPHashSelection) }, + } +} + +// Select returns an available host, if any. +func (ClientIPHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream { + address := caddyhttp.GetVar(req.Context(), caddyhttp.ClientIPVarKey).(string) + clientIP, _, err := net.SplitHostPort(address) + if err != nil { + clientIP = address // no port + } + return hostByHashing(pool, clientIP) +} + +// UnmarshalCaddyfile sets up the module from Caddyfile tokens. +func (r *ClientIPHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + for d.Next() { + if d.NextArg() { + return d.ArgErr() + } + } + return nil +} + // URIHashSelection is a policy that selects a // host by hashing the request URI. type URIHashSelection struct{} @@ -330,11 +451,95 @@ func (r *URIHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { return nil } +// QueryHashSelection is a policy that selects +// a host based on a given request query parameter. +type QueryHashSelection struct { + // The query key whose value is to be hashed and used for upstream selection. + Key string `json:"key,omitempty"` + + // The fallback policy to use if the query key is not present. Defaults to `random`. + FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"` + fallback Selector +} + +// CaddyModule returns the Caddy module information. +func (QueryHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "http.reverse_proxy.selection_policies.query", + New: func() caddy.Module { return new(QueryHashSelection) }, + } +} + +// Provision sets up the module. +func (s *QueryHashSelection) Provision(ctx caddy.Context) error { + if s.Key == "" { + return fmt.Errorf("query key is required") + } + if s.FallbackRaw == nil { + s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil) + } + mod, err := ctx.LoadModule(s, "FallbackRaw") + if err != nil { + return fmt.Errorf("loading fallback selection policy: %s", err) + } + s.fallback = mod.(Selector) + return nil +} + +// Select returns an available host, if any. +func (s QueryHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream { + // Since the query may have multiple values for the same key, + // we'll join them to avoid a problem where the user can control + // the upstream that the request goes to by sending multiple values + // for the same key, when the upstream only considers the first value. + // Keep in mind that a client changing the order of the values may + // affect which upstream is selected, but this is a semantically + // different request, because the order of the values is significant. + vals := strings.Join(req.URL.Query()[s.Key], ",") + if vals == "" { + return s.fallback.Select(pool, req, nil) + } + return hostByHashing(pool, vals) +} + +// UnmarshalCaddyfile sets up the module from Caddyfile tokens. +func (s *QueryHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + for d.Next() { + if !d.NextArg() { + return d.ArgErr() + } + s.Key = d.Val() + } + for nesting := d.Nesting(); d.NextBlock(nesting); { + switch d.Val() { + case "fallback": + if !d.NextArg() { + return d.ArgErr() + } + if s.FallbackRaw != nil { + return d.Err("fallback selection policy already specified") + } + mod, err := loadFallbackPolicy(d) + if err != nil { + return err + } + s.FallbackRaw = mod + default: + return d.Errf("unrecognized option '%s'", d.Val()) + } + } + return nil +} + // HeaderHashSelection is a policy that selects // a host based on a given request header. type HeaderHashSelection struct { // The HTTP header field whose value is to be hashed and used for upstream selection. Field string `json:"field,omitempty"` + + // The fallback policy to use if the header is not present. Defaults to `random`. + FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"` + fallback Selector } // CaddyModule returns the Caddy module information. @@ -345,12 +550,24 @@ func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo { } } -// Select returns an available host, if any. -func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream { +// Provision sets up the module. +func (s *HeaderHashSelection) Provision(ctx caddy.Context) error { if s.Field == "" { - return nil + return fmt.Errorf("header field is required") + } + if s.FallbackRaw == nil { + s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil) } + mod, err := ctx.LoadModule(s, "FallbackRaw") + if err != nil { + return fmt.Errorf("loading fallback selection policy: %s", err) + } + s.fallback = mod.(Selector) + return nil +} +// Select returns an available host, if any. +func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream { // The Host header should be obtained from the req.Host field // since net/http removes it from the header map. if s.Field == "Host" && req.Host != "" { @@ -359,7 +576,7 @@ func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http val := req.Header.Get(s.Field) if val == "" { - return RandomSelection{}.Select(pool, req, nil) + return s.fallback.Select(pool, req, nil) } return hostByHashing(pool, val) } @@ -372,6 +589,24 @@ func (s *HeaderHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } s.Field = d.Val() } + for nesting := d.Nesting(); d.NextBlock(nesting); { + switch d.Val() { + case "fallback": + if !d.NextArg() { + return d.ArgErr() + } + if s.FallbackRaw != nil { + return d.Err("fallback selection policy already specified") + } + mod, err := loadFallbackPolicy(d) + if err != nil { + return err + } + s.FallbackRaw = mod + default: + return d.Errf("unrecognized option '%s'", d.Val()) + } + } return nil } @@ -382,6 +617,10 @@ type CookieHashSelection struct { Name string `json:"name,omitempty"` // Secret to hash (Hmac256) chosen upstream in cookie Secret string `json:"secret,omitempty"` + + // The fallback policy to use if the cookie is not present. Defaults to `random`. + FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"` + fallback Selector } // CaddyModule returns the Caddy module information. @@ -392,15 +631,48 @@ func (CookieHashSelection) CaddyModule() caddy.ModuleInfo { } } -// Select returns an available host, if any. -func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream { +// Provision sets up the module. +func (s *CookieHashSelection) Provision(ctx caddy.Context) error { if s.Name == "" { s.Name = "lb" } + if s.FallbackRaw == nil { + s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil) + } + mod, err := ctx.LoadModule(s, "FallbackRaw") + if err != nil { + return fmt.Errorf("loading fallback selection policy: %s", err) + } + s.fallback = mod.(Selector) + return nil +} + +// Select returns an available host, if any. +func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream { + // selects a new Host using the fallback policy (typically random) + // and write a sticky session cookie to the response. + selectNewHost := func() *Upstream { + upstream := s.fallback.Select(pool, req, w) + if upstream == nil { + return nil + } + sha, err := hashCookie(s.Secret, upstream.Dial) + if err != nil { + return upstream + } + http.SetCookie(w, &http.Cookie{ + Name: s.Name, + Value: sha, + Path: "/", + Secure: false, + }) + return upstream + } + cookie, err := req.Cookie(s.Name) - // If there's no cookie, select new random host + // If there's no cookie, select a host using the fallback policy if err != nil || cookie == nil { - return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name) + return selectNewHost() } // If the cookie is present, loop over the available upstreams until we find a match cookieValue := cookie.Value @@ -413,13 +685,15 @@ func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http return upstream } } - // If there is no matching host, select new random host - return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name) + // If there is no matching host, select a host using the fallback policy + return selectNewHost() } // UnmarshalCaddyfile sets up the module from Caddyfile tokens. Syntax: // -// lb_policy cookie [<name> [<secret>]] +// lb_policy cookie [<name> [<secret>]] { +// fallback <policy> +// } // // By default name is `lb` func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { @@ -434,22 +708,25 @@ func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { default: return d.ArgErr() } - return nil -} - -// Select a new Host randomly and add a sticky session cookie -func selectNewHostWithCookieHashSelection(pool []*Upstream, w http.ResponseWriter, cookieSecret string, cookieName string) *Upstream { - randomHost := selectRandomHost(pool) - - if randomHost != nil { - // Hash (HMAC with some key for privacy) the upstream.Dial string as the cookie value - sha, err := hashCookie(cookieSecret, randomHost.Dial) - if err == nil { - // write the cookie. - http.SetCookie(w, &http.Cookie{Name: cookieName, Value: sha, Path: "/", Secure: false}) + for nesting := d.Nesting(); d.NextBlock(nesting); { + switch d.Val() { + case "fallback": + if !d.NextArg() { + return d.ArgErr() + } + if s.FallbackRaw != nil { + return d.Err("fallback selection policy already specified") + } + mod, err := loadFallbackPolicy(d) + if err != nil { + return err + } + s.FallbackRaw = mod + default: + return d.Errf("unrecognized option '%s'", d.Val()) } } - return randomHost + return nil } // hashCookie hashes (HMAC 256) some data with the secret @@ -512,6 +789,9 @@ func leastRequests(upstreams []*Upstream) *Upstream { if len(best) == 0 { return nil } + if len(best) == 1 { + return best[0] + } return best[weakrand.Intn(len(best))] //nolint:gosec } @@ -544,20 +824,40 @@ func hash(s string) uint32 { return h.Sum32() } +func loadFallbackPolicy(d *caddyfile.Dispenser) (json.RawMessage, error) { + name := d.Val() + modID := "http.reverse_proxy.selection_policies." + name + unm, err := caddyfile.UnmarshalModule(d, modID) + if err != nil { + return nil, err + } + sel, ok := unm.(Selector) + if !ok { + return nil, d.Errf("module %s (%T) is not a reverseproxy.Selector", modID, unm) + } + return caddyconfig.JSONModuleObject(sel, "policy", name, nil), nil +} + // Interface guards var ( _ Selector = (*RandomSelection)(nil) _ Selector = (*RandomChoiceSelection)(nil) _ Selector = (*LeastConnSelection)(nil) _ Selector = (*RoundRobinSelection)(nil) + _ Selector = (*WeightedRoundRobinSelection)(nil) _ Selector = (*FirstSelection)(nil) _ Selector = (*IPHashSelection)(nil) + _ Selector = (*ClientIPHashSelection)(nil) _ Selector = (*URIHashSelection)(nil) + _ Selector = (*QueryHashSelection)(nil) _ Selector = (*HeaderHashSelection)(nil) _ Selector = (*CookieHashSelection)(nil) - _ caddy.Validator = (*RandomChoiceSelection)(nil) + _ caddy.Validator = (*RandomChoiceSelection)(nil) + _ caddy.Provisioner = (*RandomChoiceSelection)(nil) + _ caddy.Provisioner = (*WeightedRoundRobinSelection)(nil) _ caddyfile.Unmarshaler = (*RandomChoiceSelection)(nil) + _ caddyfile.Unmarshaler = (*WeightedRoundRobinSelection)(nil) ) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go index 546a60d..dc613a5 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go @@ -15,9 +15,14 @@ package reverseproxy import ( + "context" "net/http" "net/http/httptest" "testing" + + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" ) func testPool() UpstreamPool { @@ -30,7 +35,7 @@ func testPool() UpstreamPool { func TestRoundRobinPolicy(t *testing.T) { pool := testPool() - rrPolicy := new(RoundRobinSelection) + rrPolicy := RoundRobinSelection{} req, _ := http.NewRequest("GET", "/", nil) h := rrPolicy.Select(pool, req, nil) @@ -69,9 +74,66 @@ func TestRoundRobinPolicy(t *testing.T) { } } +func TestWeightedRoundRobinPolicy(t *testing.T) { + pool := testPool() + wrrPolicy := WeightedRoundRobinSelection{ + Weights: []int{3, 2, 1}, + totalWeight: 6, + } + req, _ := http.NewRequest("GET", "/", nil) + + h := wrrPolicy.Select(pool, req, nil) + if h != pool[0] { + t.Error("Expected first weighted round robin host to be first host in the pool.") + } + h = wrrPolicy.Select(pool, req, nil) + if h != pool[0] { + t.Error("Expected second weighted round robin host to be first host in the pool.") + } + // Third selected host is 1, because counter starts at 0 + // and increments before host is selected + h = wrrPolicy.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected third weighted round robin host to be second host in the pool.") + } + h = wrrPolicy.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected fourth weighted round robin host to be second host in the pool.") + } + h = wrrPolicy.Select(pool, req, nil) + if h != pool[2] { + t.Error("Expected fifth weighted round robin host to be third host in the pool.") + } + h = wrrPolicy.Select(pool, req, nil) + if h != pool[0] { + t.Error("Expected sixth weighted round robin host to be first host in the pool.") + } + + // mark host as down + pool[0].setHealthy(false) + h = wrrPolicy.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected to skip down host.") + } + // mark host as up + pool[0].setHealthy(true) + + h = wrrPolicy.Select(pool, req, nil) + if h != pool[0] { + t.Error("Expected to select first host on availablity.") + } + // mark host as full + pool[1].countRequest(1) + pool[1].MaxRequests = 1 + h = wrrPolicy.Select(pool, req, nil) + if h != pool[2] { + t.Error("Expected to skip full host.") + } +} + func TestLeastConnPolicy(t *testing.T) { pool := testPool() - lcPolicy := new(LeastConnSelection) + lcPolicy := LeastConnSelection{} req, _ := http.NewRequest("GET", "/", nil) pool[0].countRequest(10) @@ -89,7 +151,7 @@ func TestLeastConnPolicy(t *testing.T) { func TestIPHashPolicy(t *testing.T) { pool := testPool() - ipHash := new(IPHashSelection) + ipHash := IPHashSelection{} req, _ := http.NewRequest("GET", "/", nil) // We should be able to predict where every request is routed. @@ -229,9 +291,152 @@ func TestIPHashPolicy(t *testing.T) { } } +func TestClientIPHashPolicy(t *testing.T) { + pool := testPool() + ipHash := ClientIPHashSelection{} + req, _ := http.NewRequest("GET", "/", nil) + req = req.WithContext(context.WithValue(req.Context(), caddyhttp.VarsCtxKey, make(map[string]any))) + + // We should be able to predict where every request is routed. + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.1:80") + h := ipHash.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2:80") + h = ipHash.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3:80") + h = ipHash.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4:80") + h = ipHash.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + + // we should get the same results without a port + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.1") + h = ipHash.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2") + h = ipHash.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3") + h = ipHash.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4") + h = ipHash.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + + // we should get a healthy host if the original host is unhealthy and a + // healthy host is available + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4") + pool[1].setHealthy(false) + h = ipHash.Select(pool, req, nil) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2") + h = ipHash.Select(pool, req, nil) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + pool[1].setHealthy(true) + + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3") + pool[2].setHealthy(false) + h = ipHash.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4") + h = ipHash.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + + // We should be able to resize the host pool and still be able to predict + // where a req will be routed with the same IP's used above + pool = UpstreamPool{ + {Host: new(Host), Dial: "0.0.0.2"}, + {Host: new(Host), Dial: "0.0.0.3"}, + } + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.1:80") + h = ipHash.Select(pool, req, nil) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2:80") + h = ipHash.Select(pool, req, nil) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3:80") + h = ipHash.Select(pool, req, nil) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4:80") + h = ipHash.Select(pool, req, nil) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + + // We should get nil when there are no healthy hosts + pool[0].setHealthy(false) + pool[1].setHealthy(false) + h = ipHash.Select(pool, req, nil) + if h != nil { + t.Error("Expected ip hash policy host to be nil.") + } + + // Reproduce #4135 + pool = UpstreamPool{ + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + } + pool[0].setHealthy(false) + pool[1].setHealthy(false) + pool[2].setHealthy(false) + pool[3].setHealthy(false) + pool[4].setHealthy(false) + pool[5].setHealthy(false) + pool[6].setHealthy(false) + pool[7].setHealthy(false) + pool[8].setHealthy(true) + + // We should get a result back when there is one healthy host left. + h = ipHash.Select(pool, req, nil) + if h == nil { + // If it is nil, it means we missed a host even though one is available + t.Error("Expected ip hash policy host to not be nil, but it is nil.") + } +} + func TestFirstPolicy(t *testing.T) { pool := testPool() - firstPolicy := new(FirstSelection) + firstPolicy := FirstSelection{} req := httptest.NewRequest(http.MethodGet, "/", nil) h := firstPolicy.Select(pool, req, nil) @@ -246,9 +451,85 @@ func TestFirstPolicy(t *testing.T) { } } +func TestQueryHashPolicy(t *testing.T) { + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + queryPolicy := QueryHashSelection{Key: "foo"} + if err := queryPolicy.Provision(ctx); err != nil { + t.Errorf("Provision error: %v", err) + t.FailNow() + } + + pool := testPool() + + request := httptest.NewRequest(http.MethodGet, "/?foo=1", nil) + h := queryPolicy.Select(pool, request, nil) + if h != pool[0] { + t.Error("Expected query policy host to be the first host.") + } + + request = httptest.NewRequest(http.MethodGet, "/?foo=100000", nil) + h = queryPolicy.Select(pool, request, nil) + if h != pool[0] { + t.Error("Expected query policy host to be the first host.") + } + + request = httptest.NewRequest(http.MethodGet, "/?foo=1", nil) + pool[0].setHealthy(false) + h = queryPolicy.Select(pool, request, nil) + if h != pool[1] { + t.Error("Expected query policy host to be the second host.") + } + + request = httptest.NewRequest(http.MethodGet, "/?foo=100000", nil) + h = queryPolicy.Select(pool, request, nil) + if h != pool[2] { + t.Error("Expected query policy host to be the third host.") + } + + // We should be able to resize the host pool and still be able to predict + // where a request will be routed with the same query used above + pool = UpstreamPool{ + {Host: new(Host)}, + {Host: new(Host)}, + } + + request = httptest.NewRequest(http.MethodGet, "/?foo=1", nil) + h = queryPolicy.Select(pool, request, nil) + if h != pool[0] { + t.Error("Expected query policy host to be the first host.") + } + + pool[0].setHealthy(false) + h = queryPolicy.Select(pool, request, nil) + if h != pool[1] { + t.Error("Expected query policy host to be the second host.") + } + + request = httptest.NewRequest(http.MethodGet, "/?foo=4", nil) + h = queryPolicy.Select(pool, request, nil) + if h != pool[1] { + t.Error("Expected query policy host to be the second host.") + } + + pool[0].setHealthy(false) + pool[1].setHealthy(false) + h = queryPolicy.Select(pool, request, nil) + if h != nil { + t.Error("Expected query policy policy host to be nil.") + } + + request = httptest.NewRequest(http.MethodGet, "/?foo=aa11&foo=bb22", nil) + pool = testPool() + h = queryPolicy.Select(pool, request, nil) + if h != pool[0] { + t.Error("Expected query policy host to be the first host.") + } +} + func TestURIHashPolicy(t *testing.T) { pool := testPool() - uriPolicy := new(URIHashSelection) + uriPolicy := URIHashSelection{} request := httptest.NewRequest(http.MethodGet, "/test", nil) h := uriPolicy.Select(pool, request, nil) @@ -337,8 +618,7 @@ func TestRandomChoicePolicy(t *testing.T) { pool[2].countRequest(30) request := httptest.NewRequest(http.MethodGet, "/test", nil) - randomChoicePolicy := new(RandomChoiceSelection) - randomChoicePolicy.Choose = 2 + randomChoicePolicy := RandomChoiceSelection{Choose: 2} h := randomChoicePolicy.Select(pool, request, nil) @@ -353,6 +633,14 @@ func TestRandomChoicePolicy(t *testing.T) { } func TestCookieHashPolicy(t *testing.T) { + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + cookieHashPolicy := CookieHashSelection{} + if err := cookieHashPolicy.Provision(ctx); err != nil { + t.Errorf("Provision error: %v", err) + t.FailNow() + } + pool := testPool() pool[0].Dial = "localhost:8080" pool[1].Dial = "localhost:8081" @@ -362,7 +650,7 @@ func TestCookieHashPolicy(t *testing.T) { pool[2].setHealthy(false) request := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() - cookieHashPolicy := new(CookieHashSelection) + h := cookieHashPolicy.Select(pool, request, w) cookieServer1 := w.Result().Cookies()[0] if cookieServer1 == nil { @@ -399,3 +687,59 @@ func TestCookieHashPolicy(t *testing.T) { t.Error("Expected cookieHashPolicy to set a new cookie.") } } + +func TestCookieHashPolicyWithFirstFallback(t *testing.T) { + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + cookieHashPolicy := CookieHashSelection{ + FallbackRaw: caddyconfig.JSONModuleObject(FirstSelection{}, "policy", "first", nil), + } + if err := cookieHashPolicy.Provision(ctx); err != nil { + t.Errorf("Provision error: %v", err) + t.FailNow() + } + + pool := testPool() + pool[0].Dial = "localhost:8080" + pool[1].Dial = "localhost:8081" + pool[2].Dial = "localhost:8082" + pool[0].setHealthy(true) + pool[1].setHealthy(true) + pool[2].setHealthy(true) + request := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + + h := cookieHashPolicy.Select(pool, request, w) + cookieServer1 := w.Result().Cookies()[0] + if cookieServer1 == nil { + t.Fatal("cookieHashPolicy should set a cookie") + } + if cookieServer1.Name != "lb" { + t.Error("cookieHashPolicy should set a cookie with name lb") + } + if h != pool[0] { + t.Errorf("Expected cookieHashPolicy host to be the first only available host, got %s", h) + } + request = httptest.NewRequest(http.MethodGet, "/test", nil) + w = httptest.NewRecorder() + request.AddCookie(cookieServer1) + h = cookieHashPolicy.Select(pool, request, w) + if h != pool[0] { + t.Errorf("Expected cookieHashPolicy host to stick to the first host (matching cookie), got %s", h) + } + s := w.Result().Cookies() + if len(s) != 0 { + t.Error("Expected cookieHashPolicy to not set a new cookie.") + } + pool[0].setHealthy(false) + request = httptest.NewRequest(http.MethodGet, "/test", nil) + w = httptest.NewRecorder() + request.AddCookie(cookieServer1) + h = cookieHashPolicy.Select(pool, request, w) + if h != pool[1] { + t.Errorf("Expected cookieHashPolicy to select the next first available host, got %s", h) + } + if w.Result().Cookies() == nil { + t.Error("Expected cookieHashPolicy to set a new cookie.") + } +} diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 1db107a..155a1df 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -20,6 +20,8 @@ package reverseproxy import ( "context" + "errors" + "fmt" "io" weakrand "math/rand" "mime" @@ -32,32 +34,46 @@ import ( "golang.org/x/net/http/httpguts" ) -func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) { +func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) { reqUpType := upgradeType(req.Header) resUpType := upgradeType(res.Header) // Taken from https://github.com/golang/go/commit/5c489514bc5e61ad9b5b07bd7d8ec65d66a0512a // We know reqUpType is ASCII, it's checked by the caller. if !asciiIsPrint(resUpType) { - h.logger.Debug("backend tried to switch to invalid protocol", + logger.Debug("backend tried to switch to invalid protocol", zap.String("backend_upgrade", resUpType)) return } if !asciiEqualFold(reqUpType, resUpType) { - h.logger.Debug("backend tried to switch to unexpected protocol via Upgrade header", + logger.Debug("backend tried to switch to unexpected protocol via Upgrade header", zap.String("backend_upgrade", resUpType), zap.String("requested_upgrade", reqUpType)) return } - hj, ok := rw.(http.Hijacker) + backConn, ok := res.Body.(io.ReadWriteCloser) if !ok { + logger.Error("internal error: 101 switching protocols response with non-writable body") + return + } + + // write header first, response headers should not be counted in size + // like the rest of handler chain. + copyHeader(rw.Header(), res.Header) + rw.WriteHeader(res.StatusCode) + + logger.Debug("upgrading connection") + + //nolint:bodyclose + conn, brw, hijackErr := http.NewResponseController(rw).Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { h.logger.Sugar().Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw) return } - backConn, ok := res.Body.(io.ReadWriteCloser) - if !ok { - h.logger.Error("internal error: 101 switching protocols response with non-writable body") + + if hijackErr != nil { + h.logger.Error("hijack failed on protocol switch", zap.Error(hijackErr)) return } @@ -74,18 +90,6 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite }() defer close(backConnCloseCh) - // write header first, response headers should not be counted in size - // like the rest of handler chain. - copyHeader(rw.Header(), res.Header) - rw.WriteHeader(res.StatusCode) - - logger.Debug("upgrading connection") - conn, brw, err := hj.Hijack() - if err != nil { - h.logger.Error("hijack failed on protocol switch", zap.Error(err)) - return - } - start := time.Now() defer func() { conn.Close() @@ -93,7 +97,7 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite }() if err := brw.Flush(); err != nil { - h.logger.Debug("response flush", zap.Error(err)) + logger.Debug("response flush", zap.Error(err)) return } @@ -119,10 +123,23 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite spc := switchProtocolCopier{user: conn, backend: backConn} + // setup the timeout if requested + var timeoutc <-chan time.Time + if h.StreamTimeout > 0 { + timer := time.NewTimer(time.Duration(h.StreamTimeout)) + defer timer.Stop() + timeoutc = timer.C + } + errc := make(chan error, 1) go spc.copyToBackend(errc) go spc.copyFromBackend(errc) - <-errc + select { + case err := <-errc: + logger.Debug("streaming error", zap.Error(err)) + case time := <-timeoutc: + logger.Debug("stream timed out", zap.Time("timeout", time)) + } } // flushInterval returns the p.FlushInterval value, conditionally @@ -167,38 +184,58 @@ func (h Handler) isBidirectionalStream(req *http.Request, res *http.Response) bo (ae == "identity" || ae == "") } -func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { +func (h Handler) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration, logger *zap.Logger) error { + var w io.Writer = dst + if flushInterval != 0 { - if wf, ok := dst.(writeFlusher); ok { - mlw := &maxLatencyWriter{ - dst: wf, - latency: flushInterval, - } - defer mlw.stop() + var mlwLogger *zap.Logger + if h.VerboseLogs { + mlwLogger = logger.Named("max_latency_writer") + } else { + mlwLogger = zap.NewNop() + } + mlw := &maxLatencyWriter{ + dst: dst, + //nolint:bodyclose + flush: http.NewResponseController(dst).Flush, + latency: flushInterval, + logger: mlwLogger, + } + defer mlw.stop() - // set up initial timer so headers get flushed even if body writes are delayed - mlw.flushPending = true - mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + // set up initial timer so headers get flushed even if body writes are delayed + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) - dst = mlw - } + w = mlw } buf := streamingBufPool.Get().(*[]byte) defer streamingBufPool.Put(buf) - _, err := h.copyBuffer(dst, src, *buf) + + var copyLogger *zap.Logger + if h.VerboseLogs { + copyLogger = logger + } else { + copyLogger = zap.NewNop() + } + + _, err := h.copyBuffer(w, src, *buf, copyLogger) return err } // copyBuffer returns any write errors or non-EOF read errors, and the amount // of bytes written. -func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { +func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *zap.Logger) (int64, error) { if len(buf) == 0 { buf = make([]byte, defaultBufferSize) } var written int64 for { + logger.Debug("waiting to read from upstream") nr, rerr := src.Read(buf) + logger := logger.With(zap.Int("read", nr)) + logger.Debug("read from upstream", zap.Error(rerr)) if rerr != nil && rerr != io.EOF && rerr != context.Canceled { // TODO: this could be useful to know (indeed, it revealed an error in our // fastcgi PoC earlier; but it's this single error report here that necessitates @@ -210,12 +247,17 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, er h.logger.Error("reading from backend", zap.Error(rerr)) } if nr > 0 { + logger.Debug("writing to downstream") nw, werr := dst.Write(buf[:nr]) if nw > 0 { written += int64(nw) } + logger.Debug("wrote to downstream", + zap.Int("written", nw), + zap.Int64("written_total", written), + zap.Error(werr)) if werr != nil { - return written, werr + return written, fmt.Errorf("writing: %w", werr) } if nr != nw { return written, io.ErrShortWrite @@ -223,9 +265,9 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, er } if rerr != nil { if rerr == io.EOF { - rerr = nil + return written, nil } - return written, rerr + return written, fmt.Errorf("reading: %w", rerr) } } } @@ -242,10 +284,70 @@ func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func return func() { h.connectionsMu.Lock() delete(h.connections, conn) + // if there is no connection left before the connections close timer fires + if len(h.connections) == 0 && h.connectionsCloseTimer != nil { + // we release the timer that holds the reference to Handler + if (*h.connectionsCloseTimer).Stop() { + h.logger.Debug("stopped streaming connections close timer - all connections are already closed") + } + h.connectionsCloseTimer = nil + } h.connectionsMu.Unlock() } } +// closeConnections immediately closes all hijacked connections (both to client and backend). +func (h *Handler) closeConnections() error { + var err error + h.connectionsMu.Lock() + defer h.connectionsMu.Unlock() + + for _, oc := range h.connections { + if oc.gracefulClose != nil { + // this is potentially blocking while we have the lock on the connections + // map, but that should be OK since the server has in theory shut down + // and we are no longer using the connections map + gracefulErr := oc.gracefulClose() + if gracefulErr != nil && err == nil { + err = gracefulErr + } + } + closeErr := oc.conn.Close() + if closeErr != nil && err == nil { + err = closeErr + } + } + return err +} + +// cleanupConnections closes hijacked connections. +// Depending on the value of StreamCloseDelay it does that either immediately +// or sets up a timer that will do that later. +func (h *Handler) cleanupConnections() error { + if h.StreamCloseDelay == 0 { + return h.closeConnections() + } + + h.connectionsMu.Lock() + defer h.connectionsMu.Unlock() + // the handler is shut down, no new connection can appear, + // so we can skip setting up the timer when there are no connections + if len(h.connections) > 0 { + delay := time.Duration(h.StreamCloseDelay) + h.connectionsCloseTimer = time.AfterFunc(delay, func() { + h.logger.Debug("closing streaming connections after delay", + zap.Duration("delay", delay)) + err := h.closeConnections() + if err != nil { + h.logger.Error("failed to closed connections after delay", + zap.Error(err), + zap.Duration("delay", delay)) + } + }) + } + return nil +} + // writeCloseControl sends a best-effort Close control message to the given // WebSocket connection. Thanks to @pascaldekloe who provided inspiration // from his simple implementation of this I was able to learn from at: @@ -365,29 +467,30 @@ type openConnection struct { gracefulClose func() error } -type writeFlusher interface { - io.Writer - http.Flusher -} - type maxLatencyWriter struct { - dst writeFlusher + dst io.Writer + flush func() error latency time.Duration // non-zero; negative means to flush immediately mu sync.Mutex // protects t, flushPending, and dst.Flush t *time.Timer flushPending bool + logger *zap.Logger } func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { m.mu.Lock() defer m.mu.Unlock() n, err = m.dst.Write(p) + m.logger.Debug("wrote bytes", zap.Int("n", n), zap.Error(err)) if m.latency < 0 { - m.dst.Flush() + m.logger.Debug("flushing immediately") + //nolint:errcheck + m.flush() return } if m.flushPending { + m.logger.Debug("delayed flush already pending") return } if m.t == nil { @@ -395,6 +498,7 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { } else { m.t.Reset(m.latency) } + m.logger.Debug("timer set for delayed flush", zap.Duration("duration", m.latency)) m.flushPending = true return } @@ -403,9 +507,12 @@ func (m *maxLatencyWriter) delayedFlush() { m.mu.Lock() defer m.mu.Unlock() if !m.flushPending { // if stop was called but AfterFunc already started this goroutine + m.logger.Debug("delayed flush is not pending") return } - m.dst.Flush() + m.logger.Debug("delayed flush") + //nolint:errcheck + m.flush() m.flushPending = false } @@ -445,5 +552,7 @@ var streamingBufPool = sync.Pool{ }, } -const defaultBufferSize = 32 * 1024 -const wordSize = int(unsafe.Sizeof(uintptr(0))) +const ( + defaultBufferSize = 32 * 1024 + wordSize = int(unsafe.Sizeof(uintptr(0))) +) diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go index 4ed1f1e..3f6da2f 100644 --- a/modules/caddyhttp/reverseproxy/streaming_test.go +++ b/modules/caddyhttp/reverseproxy/streaming_test.go @@ -2,8 +2,11 @@ package reverseproxy import ( "bytes" + "net/http/httptest" "strings" "testing" + + "github.com/caddyserver/caddy/v2" ) func TestHandlerCopyResponse(t *testing.T) { @@ -13,12 +16,15 @@ func TestHandlerCopyResponse(t *testing.T) { strings.Repeat("a", defaultBufferSize), strings.Repeat("123456789 123456789 123456789 12", 3000), } + dst := bytes.NewBuffer(nil) + recorder := httptest.NewRecorder() + recorder.Body = dst for _, d := range testdata { src := bytes.NewBuffer([]byte(d)) dst.Reset() - err := h.copyResponse(dst, src, 0) + err := h.copyResponse(recorder, src, 0, caddy.Log()) if err != nil { t.Errorf("failed with error: %v", err) } diff --git a/modules/caddyhttp/reverseproxy/upstreams.go b/modules/caddyhttp/reverseproxy/upstreams.go index 7a90016..2d21a5c 100644 --- a/modules/caddyhttp/reverseproxy/upstreams.go +++ b/modules/caddyhttp/reverseproxy/upstreams.go @@ -8,12 +8,12 @@ import ( "net" "net/http" "strconv" - "strings" "sync" "time" - "github.com/caddyserver/caddy/v2" "go.uber.org/zap" + + "github.com/caddyserver/caddy/v2" ) func init() { @@ -114,7 +114,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { cached := srvs[suAddr] srvsMu.RUnlock() if cached.isFresh() { - return cached.upstreams, nil + return allNew(cached.upstreams), nil } // otherwise, obtain a write-lock to update the cached value @@ -126,7 +126,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { // have refreshed it in the meantime before we re-obtained our lock cached = srvs[suAddr] if cached.isFresh() { - return cached.upstreams, nil + return allNew(cached.upstreams), nil } su.logger.Debug("refreshing SRV upstreams", @@ -145,7 +145,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { su.logger.Warn("SRV records filtered", zap.Error(err)) } - upstreams := make([]*Upstream, len(records)) + upstreams := make([]Upstream, len(records)) for i, rec := range records { su.logger.Debug("discovered SRV record", zap.String("target", rec.Target), @@ -153,7 +153,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { zap.Uint16("priority", rec.Priority), zap.Uint16("weight", rec.Weight)) addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port))) - upstreams[i] = &Upstream{Dial: addr} + upstreams[i] = Upstream{Dial: addr} } // before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full @@ -170,7 +170,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { upstreams: upstreams, } - return upstreams, nil + return allNew(upstreams), nil } func (su SRVUpstreams) String() string { @@ -206,13 +206,18 @@ func (SRVUpstreams) formattedAddr(service, proto, name string) string { type srvLookup struct { srvUpstreams SRVUpstreams freshness time.Time - upstreams []*Upstream + upstreams []Upstream } func (sl srvLookup) isFresh() bool { return time.Since(sl.freshness) < time.Duration(sl.srvUpstreams.Refresh) } +type IPVersions struct { + IPv4 *bool `json:"ipv4,omitempty"` + IPv6 *bool `json:"ipv6,omitempty"` +} + // AUpstreams provides upstreams from A/AAAA lookups. // Results are cached and refreshed at the configured // refresh interval. @@ -240,7 +245,14 @@ type AUpstreams struct { // A negative value disables this. FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"` + // The IP versions to resolve for. By default, both + // "ipv4" and "ipv6" will be enabled, which + // correspond to A and AAAA records respectively. + Versions *IPVersions `json:"versions,omitempty"` + resolver *net.Resolver + + logger *zap.Logger } // CaddyModule returns the Caddy module information. @@ -251,7 +263,8 @@ func (AUpstreams) CaddyModule() caddy.ModuleInfo { } } -func (au *AUpstreams) Provision(_ caddy.Context) error { +func (au *AUpstreams) Provision(ctx caddy.Context) error { + au.logger = ctx.Logger() if au.Refresh == 0 { au.Refresh = caddy.Duration(time.Minute) } @@ -286,14 +299,36 @@ func (au *AUpstreams) Provision(_ caddy.Context) error { func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) - auStr := repl.ReplaceAll(au.String(), "") + + resolveIpv4 := au.Versions == nil || au.Versions.IPv4 == nil || *au.Versions.IPv4 + resolveIpv6 := au.Versions == nil || au.Versions.IPv6 == nil || *au.Versions.IPv6 + + // Map ipVersion early, so we can use it as part of the cache-key. + // This should be fairly inexpensive and comes and the upside of + // allowing the same dynamic upstream (name + port combination) + // to be used multiple times with different ip versions. + // + // It also forced a cache-miss if a previously cached dynamic + // upstream changes its ip version, e.g. after a config reload, + // while keeping the cache-invalidation as simple as it currently is. + var ipVersion string + switch { + case resolveIpv4 && !resolveIpv6: + ipVersion = "ip4" + case !resolveIpv4 && resolveIpv6: + ipVersion = "ip6" + default: + ipVersion = "ip" + } + + auStr := repl.ReplaceAll(au.String()+ipVersion, "") // first, use a cheap read-lock to return a cached result quickly aAaaaMu.RLock() cached := aAaaa[auStr] aAaaaMu.RUnlock() if cached.isFresh() { - return cached.upstreams, nil + return allNew(cached.upstreams), nil } // otherwise, obtain a write-lock to update the cached value @@ -305,26 +340,33 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { // have refreshed it in the meantime before we re-obtained our lock cached = aAaaa[auStr] if cached.isFresh() { - return cached.upstreams, nil + return allNew(cached.upstreams), nil } name := repl.ReplaceAll(au.Name, "") port := repl.ReplaceAll(au.Port, "") - ips, err := au.resolver.LookupIPAddr(r.Context(), name) + au.logger.Debug("refreshing A upstreams", + zap.String("version", ipVersion), + zap.String("name", name), + zap.String("port", port)) + + ips, err := au.resolver.LookupIP(r.Context(), ipVersion, name) if err != nil { return nil, err } - upstreams := make([]*Upstream, len(ips)) + upstreams := make([]Upstream, len(ips)) for i, ip := range ips { - upstreams[i] = &Upstream{ + au.logger.Debug("discovered A record", + zap.String("ip", ip.String())) + upstreams[i] = Upstream{ Dial: net.JoinHostPort(ip.String(), port), } } // before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full - if cached.freshness.IsZero() && len(srvs) >= 100 { + if cached.freshness.IsZero() && len(aAaaa) >= 100 { for randomKey := range aAaaa { delete(aAaaa, randomKey) break @@ -337,7 +379,7 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { upstreams: upstreams, } - return upstreams, nil + return allNew(upstreams), nil } func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port) } @@ -345,7 +387,7 @@ func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port) type aLookup struct { aUpstreams AUpstreams freshness time.Time - upstreams []*Upstream + upstreams []Upstream } func (al aLookup) isFresh() bool { @@ -439,16 +481,9 @@ type UpstreamResolver struct { // and ensures they're ready to be used. func (u *UpstreamResolver) ParseAddresses() error { for _, v := range u.Addresses { - addr, err := caddy.ParseNetworkAddress(v) + addr, err := caddy.ParseNetworkAddressWithDefaults(v, "udp", 53) if err != nil { - // If a port wasn't specified for the resolver, - // try defaulting to 53 and parse again - if strings.Contains(err.Error(), "missing port in address") { - addr, err = caddy.ParseNetworkAddress(v + ":53") - } - if err != nil { - return err - } + return err } if addr.PortRangeSize() != 1 { return fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr) @@ -458,6 +493,14 @@ func (u *UpstreamResolver) ParseAddresses() error { return nil } +func allNew(upstreams []Upstream) []*Upstream { + results := make([]*Upstream, len(upstreams)) + for i := range upstreams { + results[i] = &Upstream{Dial: upstreams[i].Dial} + } + return results +} + var ( srvs = make(map[string]srvLookup) srvsMu sync.RWMutex |