Diffstat (limited to 'modules/caddyhttp/reverseproxy')
-rw-r--r--  modules/caddyhttp/reverseproxy/addresses.go                  103
-rw-r--r--  modules/caddyhttp/reverseproxy/addresses_test.go              48
-rw-r--r--  modules/caddyhttp/reverseproxy/caddyfile.go                  1107
-rw-r--r--  modules/caddyhttp/reverseproxy/command.go                     165
-rw-r--r--  modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go            43
-rw-r--r--  modules/caddyhttp/reverseproxy/fastcgi/client.go                3
-rw-r--r--  modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go              16
-rw-r--r--  modules/caddyhttp/reverseproxy/forwardauth/caddyfile.go        14
-rw-r--r--  modules/caddyhttp/reverseproxy/healthchecks.go                133
-rw-r--r--  modules/caddyhttp/reverseproxy/hosts.go                        67
-rw-r--r--  modules/caddyhttp/reverseproxy/httptransport.go                95
-rw-r--r--  modules/caddyhttp/reverseproxy/metrics.go                       2
-rw-r--r--  modules/caddyhttp/reverseproxy/reverseproxy.go                275
-rw-r--r--  modules/caddyhttp/reverseproxy/selectionpolicies.go           360
-rw-r--r--  modules/caddyhttp/reverseproxy/selectionpolicies_test.go      360
-rw-r--r--  modules/caddyhttp/reverseproxy/streaming.go                   205
-rw-r--r--  modules/caddyhttp/reverseproxy/streaming_test.go                8
-rw-r--r--  modules/caddyhttp/reverseproxy/upstreams.go                    97
18 files changed, 2131 insertions(+), 970 deletions(-)
diff --git a/modules/caddyhttp/reverseproxy/addresses.go b/modules/caddyhttp/reverseproxy/addresses.go
index 8152108..82c1c79 100644
--- a/modules/caddyhttp/reverseproxy/addresses.go
+++ b/modules/caddyhttp/reverseproxy/addresses.go
@@ -23,11 +23,46 @@ import (
"github.com/caddyserver/caddy/v2"
)
+type parsedAddr struct {
+ network, scheme, host, port string
+ valid bool
+}
+
+func (p parsedAddr) dialAddr() string {
+ if !p.valid {
+ return ""
+ }
+ // for simplest possible config, we only need to include
+ // the network portion if the user specified one
+ if p.network != "" {
+ return caddy.JoinNetworkAddress(p.network, p.host, p.port)
+ }
+
+ // if the host is a placeholder, then we don't want to join with an empty port,
+ // because that would just append an extra ':' at the end of the address.
+ if p.port == "" && strings.Contains(p.host, "{") {
+ return p.host
+ }
+ return net.JoinHostPort(p.host, p.port)
+}
+
+func (p parsedAddr) rangedPort() bool {
+ return strings.Contains(p.port, "-")
+}
+
+func (p parsedAddr) replaceablePort() bool {
+ return strings.Contains(p.port, "{") && strings.Contains(p.port, "}")
+}
+
+func (p parsedAddr) isUnix() bool {
+ return caddy.IsUnixNetwork(p.network)
+}
+
// parseUpstreamDialAddress parses configuration inputs for
// the dial address, including support for a scheme in front
// as a shortcut for the port number, and a network type,
// for example 'unix' to dial a unix socket.
-func parseUpstreamDialAddress(upstreamAddr string) (string, string, error) {
+func parseUpstreamDialAddress(upstreamAddr string) (parsedAddr, error) {
var network, scheme, host, port string
if strings.Contains(upstreamAddr, "://") {
@@ -35,46 +70,65 @@ func parseUpstreamDialAddress(upstreamAddr string) (string, string, error) {
// so we return a more user-friendly error message instead
// to explain what to do instead
if strings.Contains(upstreamAddr, "{") {
- return "", "", fmt.Errorf("due to parsing difficulties, placeholders are not allowed when an upstream address contains a scheme")
+ return parsedAddr{}, fmt.Errorf("due to parsing difficulties, placeholders are not allowed when an upstream address contains a scheme")
}
toURL, err := url.Parse(upstreamAddr)
if err != nil {
- return "", "", fmt.Errorf("parsing upstream URL: %v", err)
+ // if the error seems to be due to a port range,
+ // try to replace the port range with a dummy
+ // single port so that url.Parse() will succeed
+ if strings.Contains(err.Error(), "invalid port") && strings.Contains(err.Error(), "-") {
+ index := strings.LastIndex(upstreamAddr, ":")
+ if index == -1 {
+ return parsedAddr{}, fmt.Errorf("parsing upstream URL: %v", err)
+ }
+ portRange := upstreamAddr[index+1:]
+ if strings.Count(portRange, "-") != 1 {
+ return parsedAddr{}, fmt.Errorf("parsing upstream URL: parse \"%v\": port range invalid: %v", upstreamAddr, portRange)
+ }
+ toURL, err = url.Parse(strings.ReplaceAll(upstreamAddr, portRange, "0"))
+ if err != nil {
+ return parsedAddr{}, fmt.Errorf("parsing upstream URL: %v", err)
+ }
+ port = portRange
+ } else {
+ return parsedAddr{}, fmt.Errorf("parsing upstream URL: %v", err)
+ }
+ }
+ if port == "" {
+ port = toURL.Port()
}
// there is currently no way to perform a URL rewrite between choosing
// a backend and proxying to it, so we cannot allow extra components
// in backend URLs
if toURL.Path != "" || toURL.RawQuery != "" || toURL.Fragment != "" {
- return "", "", fmt.Errorf("for now, URLs for proxy upstreams only support scheme, host, and port components")
+ return parsedAddr{}, fmt.Errorf("for now, URLs for proxy upstreams only support scheme, host, and port components")
}
// ensure the port and scheme aren't in conflict
- urlPort := toURL.Port()
- if toURL.Scheme == "http" && urlPort == "443" {
- return "", "", fmt.Errorf("upstream address has conflicting scheme (http://) and port (:443, the HTTPS port)")
+ if toURL.Scheme == "http" && port == "443" {
+ return parsedAddr{}, fmt.Errorf("upstream address has conflicting scheme (http://) and port (:443, the HTTPS port)")
}
- if toURL.Scheme == "https" && urlPort == "80" {
- return "", "", fmt.Errorf("upstream address has conflicting scheme (https://) and port (:80, the HTTP port)")
+ if toURL.Scheme == "https" && port == "80" {
+ return parsedAddr{}, fmt.Errorf("upstream address has conflicting scheme (https://) and port (:80, the HTTP port)")
}
- if toURL.Scheme == "h2c" && urlPort == "443" {
- return "", "", fmt.Errorf("upstream address has conflicting scheme (h2c://) and port (:443, the HTTPS port)")
+ if toURL.Scheme == "h2c" && port == "443" {
+ return parsedAddr{}, fmt.Errorf("upstream address has conflicting scheme (h2c://) and port (:443, the HTTPS port)")
}
// if port is missing, attempt to infer from scheme
- if toURL.Port() == "" {
- var toPort string
+ if port == "" {
switch toURL.Scheme {
case "", "http", "h2c":
- toPort = "80"
+ port = "80"
case "https":
- toPort = "443"
+ port = "443"
}
- toURL.Host = net.JoinHostPort(toURL.Hostname(), toPort)
}
- scheme, host, port = toURL.Scheme, toURL.Hostname(), toURL.Port()
+ scheme, host = toURL.Scheme, toURL.Hostname()
} else {
var err error
network, host, port, err = caddy.SplitNetworkAddress(upstreamAddr)
@@ -93,18 +147,5 @@ func parseUpstreamDialAddress(upstreamAddr string) (string, string, error) {
network = "unix"
scheme = "h2c"
}
-
- // for simplest possible config, we only need to include
- // the network portion if the user specified one
- if network != "" {
- return caddy.JoinNetworkAddress(network, host, port), scheme, nil
- }
-
- // if the host is a placeholder, then we don't want to join with an empty port,
- // because that would just append an extra ':' at the end of the address.
- if port == "" && strings.Contains(host, "{") {
- return host, scheme, nil
- }
-
- return net.JoinHostPort(host, port), scheme, nil
+ return parsedAddr{network, scheme, host, port, true}, nil
}
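
With this change, callers receive a structured parsedAddr instead of a pre-joined dial string. A minimal in-package sketch of the new contract (illustrative; field and method names follow the struct defined above):

	pa, err := parseUpstreamDialAddress("http://localhost:8001-8004")
	if err != nil {
		// handle the parse error
	}
	// pa.scheme == "http", pa.host == "localhost", pa.port == "8001-8004"
	if pa.rangedPort() {
		// the caller must expand the range before dialing;
		// dialAddr() still joins the host with the raw range:
		_ = pa.dialAddr() // "localhost:8001-8004"
	}
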
diff --git a/modules/caddyhttp/reverseproxy/addresses_test.go b/modules/caddyhttp/reverseproxy/addresses_test.go
index 6355c75..0c51419 100644
--- a/modules/caddyhttp/reverseproxy/addresses_test.go
+++ b/modules/caddyhttp/reverseproxy/addresses_test.go
@@ -150,6 +150,24 @@ func TestParseUpstreamDialAddress(t *testing.T) {
expectScheme: "h2c",
},
{
+ input: "localhost:1001-1009",
+ expectHostPort: "localhost:1001-1009",
+ },
+ {
+ input: "{host}:1001-1009",
+ expectHostPort: "{host}:1001-1009",
+ },
+ {
+ input: "http://localhost:1001-1009",
+ expectHostPort: "localhost:1001-1009",
+ expectScheme: "http",
+ },
+ {
+ input: "https://localhost:1001-1009",
+ expectHostPort: "localhost:1001-1009",
+ expectScheme: "https",
+ },
+ {
input: "unix//var/php.sock",
expectHostPort: "unix//var/php.sock",
},
@@ -197,6 +215,26 @@ func TestParseUpstreamDialAddress(t *testing.T) {
expectErr: true,
},
{
+ input: "http://localhost:8001-8002-8003",
+ expectErr: true,
+ },
+ {
+ input: "http://localhost:8001-8002/foo:bar",
+ expectErr: true,
+ },
+ {
+ input: "http://localhost:8001-8002/foo:1",
+ expectErr: true,
+ },
+ {
+ input: "http://localhost:8001-8002/foo:1-2",
+ expectErr: true,
+ },
+ {
+ input: "http://localhost:8001-8002#foo:1",
+ expectErr: true,
+ },
+ {
input: "http://foo:443",
expectErr: true,
},
@@ -227,18 +265,18 @@ func TestParseUpstreamDialAddress(t *testing.T) {
expectScheme: "h2c",
},
} {
- actualHostPort, actualScheme, err := parseUpstreamDialAddress(tc.input)
+ actualAddr, err := parseUpstreamDialAddress(tc.input)
if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got %v", i, err)
}
if !tc.expectErr && err != nil {
t.Errorf("Test %d: Expected no error but got %v", i, err)
}
- if actualHostPort != tc.expectHostPort {
- t.Errorf("Test %d: Expected host and port '%s' but got '%s'", i, tc.expectHostPort, actualHostPort)
+ if actualAddr.dialAddr() != tc.expectHostPort {
+ t.Errorf("Test %d: input %s: Expected host and port '%s' but got '%s'", i, tc.input, tc.expectHostPort, actualAddr.dialAddr())
}
- if actualScheme != tc.expectScheme {
- t.Errorf("Test %d: Expected scheme '%s' but got '%s'", i, tc.expectScheme, actualScheme)
+ if actualAddr.scheme != tc.expectScheme {
+ t.Errorf("Test %d: Expected scheme '%s' but got '%s'", i, tc.expectScheme, actualAddr.scheme)
}
}
}
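
To run just this table-driven test with the standard Go toolchain:

	go test ./modules/caddyhttp/reverseproxy/ -run TestParseUpstreamDialAddress -v
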
diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go
index 1211188..bcbe744 100644
--- a/modules/caddyhttp/reverseproxy/caddyfile.go
+++ b/modules/caddyhttp/reverseproxy/caddyfile.go
@@ -15,12 +15,14 @@
package reverseproxy
import (
- "net"
+ "fmt"
"net/http"
"reflect"
"strconv"
"strings"
+ "github.com/dustin/go-humanize"
+
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
@@ -28,7 +30,6 @@ import (
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/headers"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite"
- "github.com/dustin/go-humanize"
)
func init() {
@@ -83,10 +84,13 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
// unhealthy_request_count <num>
//
// # streaming
-// flush_interval <duration>
+// flush_interval <duration>
// buffer_requests
// buffer_responses
-// max_buffer_size <size>
+// max_buffer_size <size>
+// stream_timeout <duration>
+// stream_close_delay <duration>
+// trace_logs
//
// # request manipulation
// trusted_proxies [private_ranges] <ranges...>
@@ -142,16 +146,9 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
h.responseMatchers = make(map[string]caddyhttp.ResponseMatcher)
// appendUpstream creates an upstream for address and adds
- // it to the list. If the address starts with "srv+" it is
- // treated as a SRV-based upstream, and any port will be
- // dropped.
+ // it to the list.
appendUpstream := func(address string) error {
- isSRV := strings.HasPrefix(address, "srv+")
- if isSRV {
- address = strings.TrimPrefix(address, "srv+")
- }
-
- dialAddr, scheme, err := parseUpstreamDialAddress(address)
+ pa, err := parseUpstreamDialAddress(address)
if err != nil {
return d.WrapErr(err)
}
@@ -159,573 +156,641 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
// the underlying JSON does not yet support different
// transports (protocols or schemes) to each backend,
// so we remember the last one we see and compare them
- if commonScheme != "" && scheme != commonScheme {
+
+ switch pa.scheme {
+ case "wss":
+ return d.Errf("the scheme wss:// is only supported in browsers; use https:// instead")
+ case "ws":
+ return d.Errf("the scheme ws:// is only supported in browsers; use http:// instead")
+ case "https", "http", "h2c", "":
+ // Do nothing or handle the valid schemes
+ default:
+ return d.Errf("unsupported URL scheme %s://", pa.scheme)
+ }
+
+ if commonScheme != "" && pa.scheme != commonScheme {
return d.Errf("for now, all proxy upstreams must use the same scheme (transport protocol); expecting '%s://' but got '%s://'",
- commonScheme, scheme)
+ commonScheme, pa.scheme)
}
- commonScheme = scheme
+ commonScheme = pa.scheme
- if isSRV {
- if host, _, err := net.SplitHostPort(dialAddr); err == nil {
- dialAddr = host
- }
- h.Upstreams = append(h.Upstreams, &Upstream{LookupSRV: dialAddr})
+ // if the port of upstream address contains a placeholder, only wrap it with the `Upstream` struct,
+ // delaying actual resolution of the address until request time.
+ if pa.replaceablePort() {
+ h.Upstreams = append(h.Upstreams, &Upstream{Dial: pa.dialAddr()})
+ return nil
+ }
+ parsedAddr, err := caddy.ParseNetworkAddress(pa.dialAddr())
+ if err != nil {
+ return d.WrapErr(err)
+ }
+
+ if pa.isUnix() || !pa.rangedPort() {
+ // unix networks don't have ports
+ h.Upstreams = append(h.Upstreams, &Upstream{
+ Dial: pa.dialAddr(),
+ })
} else {
- h.Upstreams = append(h.Upstreams, &Upstream{Dial: dialAddr})
+ // expand a port range into multiple upstreams
+ for i := parsedAddr.StartPort; i <= parsedAddr.EndPort; i++ {
+ h.Upstreams = append(h.Upstreams, &Upstream{
+ Dial: caddy.JoinNetworkAddress("", parsedAddr.Host, fmt.Sprint(i)),
+ })
+ }
}
+
return nil
}
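
In Caddyfile terms, a ranged port now expands into one upstream per port at parse time; an illustrative config:

	reverse_proxy localhost:8001-8004

	# behaves the same as listing each backend explicitly:
	reverse_proxy localhost:8001 localhost:8002 localhost:8003 localhost:8004
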
- for d.Next() {
- for _, up := range d.RemainingArgs() {
- err := appendUpstream(up)
+ d.Next() // consume the directive name
+ for _, up := range d.RemainingArgs() {
+ err := appendUpstream(up)
+ if err != nil {
+ return fmt.Errorf("parsing upstream '%s': %w", up, err)
+ }
+ }
+
+ for nesting := d.Nesting(); d.NextBlock(nesting); {
+ // if the subdirective has an "@" prefix then we
+ // parse it as a response matcher for use with "handle_response"
+ if strings.HasPrefix(d.Val(), matcherPrefix) {
+ err := caddyhttp.ParseNamedResponseMatcher(d.NewFromNextSegment(), h.responseMatchers)
if err != nil {
return err
}
+ continue
}
- for d.NextBlock(0) {
- // if the subdirective has an "@" prefix then we
- // parse it as a response matcher for use with "handle_response"
- if strings.HasPrefix(d.Val(), matcherPrefix) {
- err := caddyhttp.ParseNamedResponseMatcher(d.NewFromNextSegment(), h.responseMatchers)
+ switch d.Val() {
+ case "to":
+ args := d.RemainingArgs()
+ if len(args) == 0 {
+ return d.ArgErr()
+ }
+ for _, up := range args {
+ err := appendUpstream(up)
if err != nil {
- return err
+ return fmt.Errorf("parsing upstream '%s': %w", up, err)
}
- continue
}
- switch d.Val() {
- case "to":
- args := d.RemainingArgs()
- if len(args) == 0 {
- return d.ArgErr()
- }
- for _, up := range args {
- err := appendUpstream(up)
- if err != nil {
- return err
- }
- }
+ case "dynamic":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.DynamicUpstreams != nil {
+ return d.Err("dynamic upstreams already specified")
+ }
+ dynModule := d.Val()
+ modID := "http.reverse_proxy.upstreams." + dynModule
+ unm, err := caddyfile.UnmarshalModule(d, modID)
+ if err != nil {
+ return err
+ }
+ source, ok := unm.(UpstreamSource)
+ if !ok {
+ return d.Errf("module %s (%T) is not an UpstreamSource", modID, unm)
+ }
+ h.DynamicUpstreamsRaw = caddyconfig.JSONModuleObject(source, "source", dynModule, nil)
- case "dynamic":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.DynamicUpstreams != nil {
- return d.Err("dynamic upstreams already specified")
- }
- dynModule := d.Val()
- modID := "http.reverse_proxy.upstreams." + dynModule
- unm, err := caddyfile.UnmarshalModule(d, modID)
- if err != nil {
- return err
- }
- source, ok := unm.(UpstreamSource)
- if !ok {
- return d.Errf("module %s (%T) is not an UpstreamSource", modID, unm)
- }
- h.DynamicUpstreamsRaw = caddyconfig.JSONModuleObject(source, "source", dynModule, nil)
+ case "lb_policy":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil {
+ return d.Err("load balancing selection policy already specified")
+ }
+ name := d.Val()
+ modID := "http.reverse_proxy.selection_policies." + name
+ unm, err := caddyfile.UnmarshalModule(d, modID)
+ if err != nil {
+ return err
+ }
+ sel, ok := unm.(Selector)
+ if !ok {
+ return d.Errf("module %s (%T) is not a reverseproxy.Selector", modID, unm)
+ }
+ if h.LoadBalancing == nil {
+ h.LoadBalancing = new(LoadBalancing)
+ }
+ h.LoadBalancing.SelectionPolicyRaw = caddyconfig.JSONModuleObject(sel, "policy", name, nil)
- case "lb_policy":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil {
- return d.Err("load balancing selection policy already specified")
- }
- name := d.Val()
- modID := "http.reverse_proxy.selection_policies." + name
- unm, err := caddyfile.UnmarshalModule(d, modID)
- if err != nil {
- return err
- }
- sel, ok := unm.(Selector)
- if !ok {
- return d.Errf("module %s (%T) is not a reverseproxy.Selector", modID, unm)
- }
- if h.LoadBalancing == nil {
- h.LoadBalancing = new(LoadBalancing)
- }
- h.LoadBalancing.SelectionPolicyRaw = caddyconfig.JSONModuleObject(sel, "policy", name, nil)
+ case "lb_retries":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ tries, err := strconv.Atoi(d.Val())
+ if err != nil {
+ return d.Errf("bad lb_retries number '%s': %v", d.Val(), err)
+ }
+ if h.LoadBalancing == nil {
+ h.LoadBalancing = new(LoadBalancing)
+ }
+ h.LoadBalancing.Retries = tries
- case "lb_retries":
- if !d.NextArg() {
- return d.ArgErr()
- }
- tries, err := strconv.Atoi(d.Val())
- if err != nil {
- return d.Errf("bad lb_retries number '%s': %v", d.Val(), err)
- }
- if h.LoadBalancing == nil {
- h.LoadBalancing = new(LoadBalancing)
- }
- h.LoadBalancing.Retries = tries
+ case "lb_try_duration":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.LoadBalancing == nil {
+ h.LoadBalancing = new(LoadBalancing)
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad duration value %s: %v", d.Val(), err)
+ }
+ h.LoadBalancing.TryDuration = caddy.Duration(dur)
- case "lb_try_duration":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.LoadBalancing == nil {
- h.LoadBalancing = new(LoadBalancing)
- }
- dur, err := caddy.ParseDuration(d.Val())
- if err != nil {
- return d.Errf("bad duration value %s: %v", d.Val(), err)
- }
- h.LoadBalancing.TryDuration = caddy.Duration(dur)
+ case "lb_try_interval":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.LoadBalancing == nil {
+ h.LoadBalancing = new(LoadBalancing)
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad interval value '%s': %v", d.Val(), err)
+ }
+ h.LoadBalancing.TryInterval = caddy.Duration(dur)
- case "lb_try_interval":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.LoadBalancing == nil {
- h.LoadBalancing = new(LoadBalancing)
- }
- dur, err := caddy.ParseDuration(d.Val())
- if err != nil {
- return d.Errf("bad interval value '%s': %v", d.Val(), err)
- }
- h.LoadBalancing.TryInterval = caddy.Duration(dur)
+ case "lb_retry_match":
+ matcherSet, err := caddyhttp.ParseCaddyfileNestedMatcherSet(d)
+ if err != nil {
+ return d.Errf("failed to parse lb_retry_match: %v", err)
+ }
+ if h.LoadBalancing == nil {
+ h.LoadBalancing = new(LoadBalancing)
+ }
+ h.LoadBalancing.RetryMatchRaw = append(h.LoadBalancing.RetryMatchRaw, matcherSet)
- case "lb_retry_match":
- matcherSet, err := caddyhttp.ParseCaddyfileNestedMatcherSet(d)
- if err != nil {
- return d.Errf("failed to parse lb_retry_match: %v", err)
- }
- if h.LoadBalancing == nil {
- h.LoadBalancing = new(LoadBalancing)
- }
- h.LoadBalancing.RetryMatchRaw = append(h.LoadBalancing.RetryMatchRaw, matcherSet)
+ case "health_uri":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ h.HealthChecks.Active.URI = d.Val()
- case "health_uri":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Active == nil {
- h.HealthChecks.Active = new(ActiveHealthChecks)
- }
- h.HealthChecks.Active.URI = d.Val()
+ case "health_path":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ h.HealthChecks.Active.Path = d.Val()
+ caddy.Log().Named("config.adapter.caddyfile").Warn("the 'health_path' subdirective is deprecated, please use 'health_uri' instead!")
- case "health_path":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Active == nil {
- h.HealthChecks.Active = new(ActiveHealthChecks)
- }
- h.HealthChecks.Active.Path = d.Val()
- caddy.Log().Named("config.adapter.caddyfile").Warn("the 'health_path' subdirective is deprecated, please use 'health_uri' instead!")
+ case "health_port":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ portNum, err := strconv.Atoi(d.Val())
+ if err != nil {
+ return d.Errf("bad port number '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Active.Port = portNum
- case "health_port":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Active == nil {
- h.HealthChecks.Active = new(ActiveHealthChecks)
+ case "health_headers":
+ healthHeaders := make(http.Header)
+ for nesting := d.Nesting(); d.NextBlock(nesting); {
+ key := d.Val()
+ values := d.RemainingArgs()
+ if len(values) == 0 {
+ values = append(values, "")
}
- portNum, err := strconv.Atoi(d.Val())
- if err != nil {
- return d.Errf("bad port number '%s': %v", d.Val(), err)
- }
- h.HealthChecks.Active.Port = portNum
+ healthHeaders[key] = append(healthHeaders[key], values...)
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ h.HealthChecks.Active.Headers = healthHeaders
- case "health_headers":
- healthHeaders := make(http.Header)
- for nesting := d.Nesting(); d.NextBlock(nesting); {
- key := d.Val()
- values := d.RemainingArgs()
- if len(values) == 0 {
- values = append(values, "")
- }
- healthHeaders[key] = values
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Active == nil {
- h.HealthChecks.Active = new(ActiveHealthChecks)
- }
- h.HealthChecks.Active.Headers = healthHeaders
+ case "health_interval":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad interval value %s: %v", d.Val(), err)
+ }
+ h.HealthChecks.Active.Interval = caddy.Duration(dur)
- case "health_interval":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Active == nil {
- h.HealthChecks.Active = new(ActiveHealthChecks)
- }
- dur, err := caddy.ParseDuration(d.Val())
- if err != nil {
- return d.Errf("bad interval value %s: %v", d.Val(), err)
- }
- h.HealthChecks.Active.Interval = caddy.Duration(dur)
+ case "health_timeout":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad timeout value %s: %v", d.Val(), err)
+ }
+ h.HealthChecks.Active.Timeout = caddy.Duration(dur)
- case "health_timeout":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Active == nil {
- h.HealthChecks.Active = new(ActiveHealthChecks)
- }
- dur, err := caddy.ParseDuration(d.Val())
- if err != nil {
- return d.Errf("bad timeout value %s: %v", d.Val(), err)
- }
- h.HealthChecks.Active.Timeout = caddy.Duration(dur)
+ case "health_status":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ val := d.Val()
+ if len(val) == 3 && strings.HasSuffix(val, "xx") {
+ val = val[:1]
+ }
+ statusNum, err := strconv.Atoi(val)
+ if err != nil {
+ return d.Errf("bad status value '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Active.ExpectStatus = statusNum
- case "health_status":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Active == nil {
- h.HealthChecks.Active = new(ActiveHealthChecks)
- }
- val := d.Val()
- if len(val) == 3 && strings.HasSuffix(val, "xx") {
- val = val[:1]
- }
- statusNum, err := strconv.Atoi(val)
- if err != nil {
- return d.Errf("bad status value '%s': %v", d.Val(), err)
- }
- h.HealthChecks.Active.ExpectStatus = statusNum
+ case "health_body":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ h.HealthChecks.Active.ExpectBody = d.Val()
- case "health_body":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Active == nil {
- h.HealthChecks.Active = new(ActiveHealthChecks)
- }
- h.HealthChecks.Active.ExpectBody = d.Val()
+ case "max_fails":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Passive == nil {
+ h.HealthChecks.Passive = new(PassiveHealthChecks)
+ }
+ maxFails, err := strconv.Atoi(d.Val())
+ if err != nil {
+ return d.Errf("invalid maximum fail count '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Passive.MaxFails = maxFails
- case "max_fails":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Passive == nil {
- h.HealthChecks.Passive = new(PassiveHealthChecks)
+ case "fail_duration":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Passive == nil {
+ h.HealthChecks.Passive = new(PassiveHealthChecks)
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad duration value '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Passive.FailDuration = caddy.Duration(dur)
+
+ case "unhealthy_request_count":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Passive == nil {
+ h.HealthChecks.Passive = new(PassiveHealthChecks)
+ }
+ maxConns, err := strconv.Atoi(d.Val())
+ if err != nil {
+ return d.Errf("invalid maximum connection count '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Passive.UnhealthyRequestCount = maxConns
+
+ case "unhealthy_status":
+ args := d.RemainingArgs()
+ if len(args) == 0 {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Passive == nil {
+ h.HealthChecks.Passive = new(PassiveHealthChecks)
+ }
+ for _, arg := range args {
+ if len(arg) == 3 && strings.HasSuffix(arg, "xx") {
+ arg = arg[:1]
}
- maxFails, err := strconv.Atoi(d.Val())
+ statusNum, err := strconv.Atoi(arg)
if err != nil {
- return d.Errf("invalid maximum fail count '%s': %v", d.Val(), err)
+ return d.Errf("bad status value '%s': %v", d.Val(), err)
}
- h.HealthChecks.Passive.MaxFails = maxFails
+ h.HealthChecks.Passive.UnhealthyStatus = append(h.HealthChecks.Passive.UnhealthyStatus, statusNum)
+ }
- case "fail_duration":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Passive == nil {
- h.HealthChecks.Passive = new(PassiveHealthChecks)
- }
+ case "unhealthy_latency":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Passive == nil {
+ h.HealthChecks.Passive = new(PassiveHealthChecks)
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad duration value '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Passive.UnhealthyLatency = caddy.Duration(dur)
+
+ case "flush_interval":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if fi, err := strconv.Atoi(d.Val()); err == nil {
+ h.FlushInterval = caddy.Duration(fi)
+ } else {
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("bad duration value '%s': %v", d.Val(), err)
}
- h.HealthChecks.Passive.FailDuration = caddy.Duration(dur)
+ h.FlushInterval = caddy.Duration(dur)
+ }
- case "unhealthy_request_count":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Passive == nil {
- h.HealthChecks.Passive = new(PassiveHealthChecks)
- }
- maxConns, err := strconv.Atoi(d.Val())
+ case "request_buffers", "response_buffers":
+ subdir := d.Val()
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ val := d.Val()
+ var size int64
+ if val == "unlimited" {
+ size = -1
+ } else {
+ usize, err := humanize.ParseBytes(val)
if err != nil {
- return d.Errf("invalid maximum connection count '%s': %v", d.Val(), err)
+ return d.Errf("invalid byte size '%s': %v", val, err)
}
- h.HealthChecks.Passive.UnhealthyRequestCount = maxConns
+ size = int64(usize)
+ }
+ if d.NextArg() {
+ return d.ArgErr()
+ }
+ if subdir == "request_buffers" {
+ h.RequestBuffers = size
+ } else if subdir == "response_buffers" {
+ h.ResponseBuffers = size
+ }
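
The reworked subdirectives accept either a byte size (parsed by go-humanize) or the keyword "unlimited", which maps to -1 internally; an illustrative config:

	reverse_proxy localhost:9000 {
		request_buffers 4KiB
		response_buffers unlimited
	}
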
- case "unhealthy_status":
- args := d.RemainingArgs()
- if len(args) == 0 {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Passive == nil {
- h.HealthChecks.Passive = new(PassiveHealthChecks)
- }
- for _, arg := range args {
- if len(arg) == 3 && strings.HasSuffix(arg, "xx") {
- arg = arg[:1]
- }
- statusNum, err := strconv.Atoi(arg)
- if err != nil {
- return d.Errf("bad status value '%s': %v", d.Val(), err)
- }
- h.HealthChecks.Passive.UnhealthyStatus = append(h.HealthChecks.Passive.UnhealthyStatus, statusNum)
- }
+ // TODO: These three properties are deprecated; remove them sometime after v2.6.4
+ case "buffer_requests": // TODO: deprecated
+ if d.NextArg() {
+ return d.ArgErr()
+ }
+ caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: buffer_requests: use request_buffers instead (with a maximum buffer size)")
+ h.DeprecatedBufferRequests = true
+ case "buffer_responses": // TODO: deprecated
+ if d.NextArg() {
+ return d.ArgErr()
+ }
+ caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: buffer_responses: use response_buffers instead (with a maximum buffer size)")
+ h.DeprecatedBufferResponses = true
+ case "max_buffer_size": // TODO: deprecated
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ size, err := humanize.ParseBytes(d.Val())
+ if err != nil {
+ return d.Errf("invalid byte size '%s': %v", d.Val(), err)
+ }
+ if d.NextArg() {
+ return d.ArgErr()
+ }
+ caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: max_buffer_size: use request_buffers and/or response_buffers instead (with maximum buffer sizes)")
+ h.DeprecatedMaxBufferSize = int64(size)
- case "unhealthy_latency":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.HealthChecks == nil {
- h.HealthChecks = new(HealthChecks)
- }
- if h.HealthChecks.Passive == nil {
- h.HealthChecks.Passive = new(PassiveHealthChecks)
- }
+ case "stream_timeout":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if fi, err := strconv.Atoi(d.Val()); err == nil {
+ h.StreamTimeout = caddy.Duration(fi)
+ } else {
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("bad duration value '%s': %v", d.Val(), err)
}
- h.HealthChecks.Passive.UnhealthyLatency = caddy.Duration(dur)
-
- case "flush_interval":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if fi, err := strconv.Atoi(d.Val()); err == nil {
- h.FlushInterval = caddy.Duration(fi)
- } else {
- dur, err := caddy.ParseDuration(d.Val())
- if err != nil {
- return d.Errf("bad duration value '%s': %v", d.Val(), err)
- }
- h.FlushInterval = caddy.Duration(dur)
- }
+ h.StreamTimeout = caddy.Duration(dur)
+ }
- case "request_buffers", "response_buffers":
- subdir := d.Val()
- if !d.NextArg() {
- return d.ArgErr()
- }
- size, err := humanize.ParseBytes(d.Val())
+ case "stream_close_delay":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if fi, err := strconv.Atoi(d.Val()); err == nil {
+ h.StreamCloseDelay = caddy.Duration(fi)
+ } else {
+ dur, err := caddy.ParseDuration(d.Val())
if err != nil {
- return d.Errf("invalid byte size '%s': %v", d.Val(), err)
- }
- if d.NextArg() {
- return d.ArgErr()
- }
- if subdir == "request_buffers" {
- h.RequestBuffers = int64(size)
- } else if subdir == "response_buffers" {
- h.ResponseBuffers = int64(size)
-
+ return d.Errf("bad duration value '%s': %v", d.Val(), err)
}
+ h.StreamCloseDelay = caddy.Duration(dur)
+ }
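
Like flush_interval, both new stream settings accept a duration string (a bare integer is also tolerated, per the strconv.Atoi fallback above); an illustrative config:

	reverse_proxy localhost:9000 {
		stream_timeout 24h
		stream_close_delay 5s
	}
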
- // TODO: These three properties are deprecated; remove them sometime after v2.6.4
- case "buffer_requests": // TODO: deprecated
- if d.NextArg() {
- return d.ArgErr()
+ case "trusted_proxies":
+ for d.NextArg() {
+ if d.Val() == "private_ranges" {
+ h.TrustedProxies = append(h.TrustedProxies, caddyhttp.PrivateRangesCIDR()...)
+ continue
}
- caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: buffer_requests: use request_buffers instead (with a maximum buffer size)")
- h.DeprecatedBufferRequests = true
- case "buffer_responses": // TODO: deprecated
- if d.NextArg() {
- return d.ArgErr()
- }
- caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: buffer_responses: use response_buffers instead (with a maximum buffer size)")
- h.DeprecatedBufferResponses = true
- case "max_buffer_size": // TODO: deprecated
- if !d.NextArg() {
- return d.ArgErr()
- }
- size, err := humanize.ParseBytes(d.Val())
- if err != nil {
- return d.Errf("invalid byte size '%s': %v", d.Val(), err)
- }
- if d.NextArg() {
- return d.ArgErr()
- }
- caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: max_buffer_size: use request_buffers and/or response_buffers instead (with maximum buffer sizes)")
- h.DeprecatedMaxBufferSize = int64(size)
+ h.TrustedProxies = append(h.TrustedProxies, d.Val())
+ }
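
An illustrative trusted_proxies config combining the private_ranges shortcut with an explicit CIDR (192.0.2.0/24 is a documentation range used here as a placeholder):

	reverse_proxy localhost:9000 {
		trusted_proxies private_ranges 192.0.2.0/24
	}
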
- case "trusted_proxies":
- for d.NextArg() {
- if d.Val() == "private_ranges" {
- h.TrustedProxies = append(h.TrustedProxies, caddyhttp.PrivateRangesCIDR()...)
- continue
- }
- h.TrustedProxies = append(h.TrustedProxies, d.Val())
- }
+ case "header_up":
+ var err error
- case "header_up":
- var err error
+ if h.Headers == nil {
+ h.Headers = new(headers.Handler)
+ }
+ if h.Headers.Request == nil {
+ h.Headers.Request = new(headers.HeaderOps)
+ }
+ args := d.RemainingArgs()
- if h.Headers == nil {
- h.Headers = new(headers.Handler)
+ switch len(args) {
+ case 1:
+ err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], "", "")
+ case 2:
+ // some lint checks, I guess
+ if strings.EqualFold(args[0], "host") && (args[1] == "{hostport}" || args[1] == "{http.request.hostport}") {
+ caddy.Log().Named("caddyfile").Warn("Unnecessary header_up Host: the reverse proxy's default behavior is to pass headers to the upstream")
}
- if h.Headers.Request == nil {
- h.Headers.Request = new(headers.HeaderOps)
+ if strings.EqualFold(args[0], "x-forwarded-for") && (args[1] == "{remote}" || args[1] == "{http.request.remote}" || args[1] == "{remote_host}" || args[1] == "{http.request.remote.host}") {
+ caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-For: the reverse proxy's default behavior is to pass headers to the upstream")
}
- args := d.RemainingArgs()
-
- switch len(args) {
- case 1:
- err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], "", "")
- case 2:
- // some lint checks, I guess
- if strings.EqualFold(args[0], "host") && (args[1] == "{hostport}" || args[1] == "{http.request.hostport}") {
- caddy.Log().Named("caddyfile").Warn("Unnecessary header_up Host: the reverse proxy's default behavior is to pass headers to the upstream")
- }
- if strings.EqualFold(args[0], "x-forwarded-for") && (args[1] == "{remote}" || args[1] == "{http.request.remote}" || args[1] == "{remote_host}" || args[1] == "{http.request.remote.host}") {
- caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-For: the reverse proxy's default behavior is to pass headers to the upstream")
- }
- if strings.EqualFold(args[0], "x-forwarded-proto") && (args[1] == "{scheme}" || args[1] == "{http.request.scheme}") {
- caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-Proto: the reverse proxy's default behavior is to pass headers to the upstream")
- }
- if strings.EqualFold(args[0], "x-forwarded-host") && (args[1] == "{host}" || args[1] == "{http.request.host}" || args[1] == "{hostport}" || args[1] == "{http.request.hostport}") {
- caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-Host: the reverse proxy's default behavior is to pass headers to the upstream")
- }
- err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], args[1], "")
- case 3:
- err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], args[1], args[2])
- default:
- return d.ArgErr()
+ if strings.EqualFold(args[0], "x-forwarded-proto") && (args[1] == "{scheme}" || args[1] == "{http.request.scheme}") {
+ caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-Proto: the reverse proxy's default behavior is to pass headers to the upstream")
}
-
- if err != nil {
- return d.Err(err.Error())
+ if strings.EqualFold(args[0], "x-forwarded-host") && (args[1] == "{host}" || args[1] == "{http.request.host}" || args[1] == "{hostport}" || args[1] == "{http.request.hostport}") {
+ caddy.Log().Named("caddyfile").Warn("Unnecessary header_up X-Forwarded-Host: the reverse proxy's default behavior is to pass headers to the upstream")
}
+ err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], args[1], "")
+ case 3:
+ err = headers.CaddyfileHeaderOp(h.Headers.Request, args[0], args[1], args[2])
+ default:
+ return d.ArgErr()
+ }
- case "header_down":
- var err error
+ if err != nil {
+ return d.Err(err.Error())
+ }
- if h.Headers == nil {
- h.Headers = new(headers.Handler)
- }
- if h.Headers.Response == nil {
- h.Headers.Response = &headers.RespHeaderOps{
- HeaderOps: new(headers.HeaderOps),
- }
- }
- args := d.RemainingArgs()
- switch len(args) {
- case 1:
- err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], "", "")
- case 2:
- err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], args[1], "")
- case 3:
- err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], args[1], args[2])
- default:
- return d.ArgErr()
- }
+ case "header_down":
+ var err error
- if err != nil {
- return d.Err(err.Error())
- }
-
- case "method":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.Rewrite == nil {
- h.Rewrite = &rewrite.Rewrite{}
- }
- h.Rewrite.Method = d.Val()
- if d.NextArg() {
- return d.ArgErr()
+ if h.Headers == nil {
+ h.Headers = new(headers.Handler)
+ }
+ if h.Headers.Response == nil {
+ h.Headers.Response = &headers.RespHeaderOps{
+ HeaderOps: new(headers.HeaderOps),
}
+ }
+ args := d.RemainingArgs()
+ switch len(args) {
+ case 1:
+ err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], "", "")
+ case 2:
+ err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], args[1], "")
+ case 3:
+ err = headers.CaddyfileHeaderOp(h.Headers.Response.HeaderOps, args[0], args[1], args[2])
+ default:
+ return d.ArgErr()
+ }
- case "rewrite":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.Rewrite == nil {
- h.Rewrite = &rewrite.Rewrite{}
- }
- h.Rewrite.URI = d.Val()
- if d.NextArg() {
- return d.ArgErr()
- }
+ if err != nil {
+ return d.Err(err.Error())
+ }
- case "transport":
- if !d.NextArg() {
- return d.ArgErr()
- }
- if h.TransportRaw != nil {
- return d.Err("transport already specified")
- }
- transportModuleName = d.Val()
- modID := "http.reverse_proxy.transport." + transportModuleName
- unm, err := caddyfile.UnmarshalModule(d, modID)
- if err != nil {
- return err
- }
- rt, ok := unm.(http.RoundTripper)
- if !ok {
- return d.Errf("module %s (%T) is not a RoundTripper", modID, unm)
- }
- transport = rt
+ case "method":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.Rewrite == nil {
+ h.Rewrite = &rewrite.Rewrite{}
+ }
+ h.Rewrite.Method = d.Val()
+ if d.NextArg() {
+ return d.ArgErr()
+ }
- case "handle_response":
- // delegate the parsing of handle_response to the caller,
- // since we need the httpcaddyfile.Helper to parse subroutes.
- // See h.FinalizeUnmarshalCaddyfile
- h.handleResponseSegments = append(h.handleResponseSegments, d.NewFromNextSegment())
+ case "rewrite":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.Rewrite == nil {
+ h.Rewrite = &rewrite.Rewrite{}
+ }
+ h.Rewrite.URI = d.Val()
+ if d.NextArg() {
+ return d.ArgErr()
+ }
- case "replace_status":
- args := d.RemainingArgs()
- if len(args) != 1 && len(args) != 2 {
- return d.Errf("must have one or two arguments: an optional response matcher, and a status code")
- }
+ case "transport":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.TransportRaw != nil {
+ return d.Err("transport already specified")
+ }
+ transportModuleName = d.Val()
+ modID := "http.reverse_proxy.transport." + transportModuleName
+ unm, err := caddyfile.UnmarshalModule(d, modID)
+ if err != nil {
+ return err
+ }
+ rt, ok := unm.(http.RoundTripper)
+ if !ok {
+ return d.Errf("module %s (%T) is not a RoundTripper", modID, unm)
+ }
+ transport = rt
+
+ case "handle_response":
+ // delegate the parsing of handle_response to the caller,
+ // since we need the httpcaddyfile.Helper to parse subroutes.
+ // See h.FinalizeUnmarshalCaddyfile
+ h.handleResponseSegments = append(h.handleResponseSegments, d.NewFromNextSegment())
+
+ case "replace_status":
+ args := d.RemainingArgs()
+ if len(args) != 1 && len(args) != 2 {
+ return d.Errf("must have one or two arguments: an optional response matcher, and a status code")
+ }
- responseHandler := caddyhttp.ResponseHandler{}
+ responseHandler := caddyhttp.ResponseHandler{}
- if len(args) == 2 {
- if !strings.HasPrefix(args[0], matcherPrefix) {
- return d.Errf("must use a named response matcher, starting with '@'")
- }
- foundMatcher, ok := h.responseMatchers[args[0]]
- if !ok {
- return d.Errf("no named response matcher defined with name '%s'", args[0][1:])
- }
- responseHandler.Match = &foundMatcher
- responseHandler.StatusCode = caddyhttp.WeakString(args[1])
- } else if len(args) == 1 {
- responseHandler.StatusCode = caddyhttp.WeakString(args[0])
+ if len(args) == 2 {
+ if !strings.HasPrefix(args[0], matcherPrefix) {
+ return d.Errf("must use a named response matcher, starting with '@'")
}
-
- // make sure there's no block, cause it doesn't make sense
- if d.NextBlock(1) {
- return d.Errf("cannot define routes for 'replace_status', use 'handle_response' instead.")
+ foundMatcher, ok := h.responseMatchers[args[0]]
+ if !ok {
+ return d.Errf("no named response matcher defined with name '%s'", args[0][1:])
}
+ responseHandler.Match = &foundMatcher
+ responseHandler.StatusCode = caddyhttp.WeakString(args[1])
+ } else if len(args) == 1 {
+ responseHandler.StatusCode = caddyhttp.WeakString(args[0])
+ }
- h.HandleResponse = append(
- h.HandleResponse,
- responseHandler,
- )
+ // make sure there's no block, cause it doesn't make sense
+ if d.NextBlock(1) {
+ return d.Errf("cannot define routes for 'replace_status', use 'handle_response' instead.")
+ }
- default:
- return d.Errf("unrecognized subdirective %s", d.Val())
+ h.HandleResponse = append(
+ h.HandleResponse,
+ responseHandler,
+ )
+
+ case "verbose_logs":
+ if h.VerboseLogs {
+ return d.Err("verbose_logs already specified")
}
+ h.VerboseLogs = true
+
+ default:
+ return d.Errf("unrecognized subdirective %s", d.Val())
}
}
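
Taken together, the reflowed switch covers the same Caddyfile surface as before; a hedged end-to-end example touching several of the subdirectives parsed above:

	reverse_proxy node1:8080 node2:8080 {
		lb_policy first
		lb_retries 2
		health_uri /healthz
		health_interval 10s
		health_status 2xx
		fail_duration 30s
		unhealthy_status 5xx
		verbose_logs
	}
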
@@ -918,6 +983,17 @@ func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
h.MaxResponseHeaderSize = int64(size)
+ case "proxy_protocol":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ switch proxyProtocol := d.Val(); proxyProtocol {
+ case "v1", "v2":
+ h.ProxyProtocol = proxyProtocol
+ default:
+ return d.Errf("invalid proxy protocol version '%s'", proxyProtocol)
+ }
+
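
Per the switch above, only versions v1 and v2 are accepted; an illustrative transport config:

	reverse_proxy localhost:9000 {
		transport http {
			proxy_protocol v2
		}
	}
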
case "dial_timeout":
if !d.NextArg() {
return d.ArgErr()
@@ -1324,6 +1400,7 @@ func (u *SRVUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
// resolvers <resolvers...>
// dial_timeout <timeout>
// dial_fallback_delay <timeout>
+// versions ipv4|ipv6
// }
func (u *AUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
for d.Next() {
@@ -1397,8 +1474,30 @@ func (u *AUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
u.FallbackDelay = caddy.Duration(dur)
+ case "versions":
+ args := d.RemainingArgs()
+ if len(args) == 0 {
+ return d.Errf("must specify at least one version")
+ }
+
+ if u.Versions == nil {
+ u.Versions = &IPVersions{}
+ }
+
+ trueBool := true
+ for _, arg := range args {
+ switch arg {
+ case "ipv4":
+ u.Versions.IPv4 = &trueBool
+ case "ipv6":
+ u.Versions.IPv6 = &trueBool
+ default:
+ return d.Errf("unsupported version: '%s'", arg)
+ }
+ }
+
default:
- return d.Errf("unrecognized srv option '%s'", d.Val())
+ return d.Errf("unrecognized a option '%s'", d.Val())
}
}
}
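
An illustrative dynamic A/AAAA upstream block using the new versions subdirective (the name and port values shown are placeholders):

	reverse_proxy {
		dynamic a {
			name service.internal
			port 9000
			versions ipv4
		}
	}
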
diff --git a/modules/caddyhttp/reverseproxy/command.go b/modules/caddyhttp/reverseproxy/command.go
index 44f4c22..11f935c 100644
--- a/modules/caddyhttp/reverseproxy/command.go
+++ b/modules/caddyhttp/reverseproxy/command.go
@@ -16,26 +16,28 @@ package reverseproxy
import (
"encoding/json"
- "flag"
"fmt"
"net/http"
"strconv"
+ "strings"
+
+ "github.com/spf13/cobra"
+ "go.uber.org/zap"
+
+ caddycmd "github.com/caddyserver/caddy/v2/cmd"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig"
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
- caddycmd "github.com/caddyserver/caddy/v2/cmd"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/headers"
"github.com/caddyserver/caddy/v2/modules/caddytls"
- "go.uber.org/zap"
)
func init() {
caddycmd.RegisterCommand(caddycmd.Command{
Name: "reverse-proxy",
- Func: cmdReverseProxy,
- Usage: "[--from <addr>] [--to <addr>] [--change-host-header]",
+ Usage: `[--from <addr>] [--to <addr>] [--change-host-header] [--insecure] [--internal-certs] [--disable-redirects] [--header-up "Field: value"] [--header-down "Field: value"] [--access-log] [--debug]`,
Short: "A quick and production-ready reverse proxy",
Long: `
A simple but production-ready reverse proxy. Useful for quick deployments,
@@ -52,21 +54,33 @@ If the --from address has a host or IP, Caddy will attempt to serve the
proxy over HTTPS with a certificate (unless overridden by the HTTP scheme
or port).
-If --change-host-header is set, the Host header on the request will be modified
-from its original incoming value to the address of the upstream. (Otherwise, by
-default, all incoming headers are passed through unmodified.)
+If serving HTTPS:
+ --disable-redirects can be used to avoid binding to the HTTP port.
+ --internal-certs can be used to force issuance of certs using the internal
+ CA instead of attempting to issue a public certificate.
+
+For proxying:
+ --header-up can be used to set a request header to send to the upstream.
+ --header-down can be used to set a response header to send back to the client.
+ --change-host-header sets the Host header on the request to the address
+ of the upstream, instead of defaulting to the incoming Host header.
+ This is a shortcut for --header-up "Host: {http.reverse_proxy.upstream.hostport}".
+ --insecure disables TLS verification with the upstream. WARNING: THIS
+ DISABLES SECURITY BY NOT VERIFYING THE UPSTREAM'S CERTIFICATE.
`,
- Flags: func() *flag.FlagSet {
- fs := flag.NewFlagSet("reverse-proxy", flag.ExitOnError)
- fs.String("from", "localhost", "Address on which to receive traffic")
- fs.Var(&reverseProxyCmdTo, "to", "Upstream address(es) to which traffic should be sent")
- fs.Bool("change-host-header", false, "Set upstream Host header to address of upstream")
- fs.Bool("insecure", false, "Disable TLS verification (WARNING: DISABLES SECURITY BY NOT VERIFYING SSL CERTIFICATES!)")
- fs.Bool("internal-certs", false, "Use internal CA for issuing certs")
- fs.Bool("debug", false, "Enable verbose debug logs")
- fs.Bool("disable-redirects", false, "Disable HTTP->HTTPS redirects")
- return fs
- }(),
+ CobraFunc: func(cmd *cobra.Command) {
+ cmd.Flags().StringP("from", "f", "localhost", "Address on which to receive traffic")
+ cmd.Flags().StringSliceP("to", "t", []string{}, "Upstream address(es) to which traffic should be sent")
+ cmd.Flags().BoolP("change-host-header", "c", false, "Set upstream Host header to address of upstream")
+ cmd.Flags().BoolP("insecure", "", false, "Disable TLS verification (WARNING: DISABLES SECURITY BY NOT VERIFYING TLS CERTIFICATES!)")
+ cmd.Flags().BoolP("disable-redirects", "r", false, "Disable HTTP->HTTPS redirects")
+ cmd.Flags().BoolP("internal-certs", "i", false, "Use internal CA for issuing certs")
+ cmd.Flags().StringSliceP("header-up", "H", []string{}, "Set a request header to send to the upstream (format: \"Field: value\")")
+ cmd.Flags().StringSliceP("header-down", "d", []string{}, "Set a response header to send back to the client (format: \"Field: value\")")
+ cmd.Flags().BoolP("access-log", "", false, "Enable the access log")
+ cmd.Flags().BoolP("debug", "v", false, "Enable verbose debug logs")
+ cmd.RunE = caddycmd.WrapCommandFuncForCobra(cmdReverseProxy)
+ },
})
}
@@ -76,14 +90,19 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) {
from := fs.String("from")
changeHost := fs.Bool("change-host-header")
insecure := fs.Bool("insecure")
+ disableRedir := fs.Bool("disable-redirects")
internalCerts := fs.Bool("internal-certs")
+ accessLog := fs.Bool("access-log")
debug := fs.Bool("debug")
- disableRedir := fs.Bool("disable-redirects")
httpPort := strconv.Itoa(caddyhttp.DefaultHTTPPort)
httpsPort := strconv.Itoa(caddyhttp.DefaultHTTPSPort)
- if len(reverseProxyCmdTo) == 0 {
+ to, err := fs.GetStringSlice("to")
+ if err != nil {
+ return caddy.ExitCodeFailedStartup, fmt.Errorf("invalid to flag: %v", err)
+ }
+ if len(to) == 0 {
return caddy.ExitCodeFailedStartup, fmt.Errorf("--to is required")
}
@@ -112,17 +131,17 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) {
// set up the upstream address; assume missing information from given parts
// mixing schemes isn't supported, so use first defined (if available)
- toAddresses := make([]string, len(reverseProxyCmdTo))
+ toAddresses := make([]string, len(to))
var toScheme string
- for i, toLoc := range reverseProxyCmdTo {
- addr, scheme, err := parseUpstreamDialAddress(toLoc)
+ for i, toLoc := range to {
+ addr, err := parseUpstreamDialAddress(toLoc)
if err != nil {
return caddy.ExitCodeFailedStartup, fmt.Errorf("invalid upstream address %s: %v", toLoc, err)
}
- if scheme != "" && toScheme == "" {
- toScheme = scheme
+ if addr.scheme != "" && toScheme == "" {
+ toScheme = addr.scheme
}
- toAddresses[i] = addr
+ toAddresses[i] = addr.dialAddr()
}
// proceed to build the handler and server
@@ -136,9 +155,24 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) {
upstreamPool := UpstreamPool{}
for _, toAddr := range toAddresses {
- upstreamPool = append(upstreamPool, &Upstream{
- Dial: toAddr,
- })
+ parsedAddr, err := caddy.ParseNetworkAddress(toAddr)
+ if err != nil {
+ return caddy.ExitCodeFailedStartup, fmt.Errorf("invalid upstream address %s: %v", toAddr, err)
+ }
+
+ if parsedAddr.StartPort == 0 && parsedAddr.EndPort == 0 {
+ // unix networks don't have ports
+ upstreamPool = append(upstreamPool, &Upstream{
+ Dial: toAddr,
+ })
+ } else {
+ // expand a port range into multiple upstreams
+ for i := parsedAddr.StartPort; i <= parsedAddr.EndPort; i++ {
+ upstreamPool = append(upstreamPool, &Upstream{
+ Dial: caddy.JoinNetworkAddress("", parsedAddr.Host, fmt.Sprint(i)),
+ })
+ }
+ }
}
handler := Handler{
@@ -146,16 +180,64 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) {
Upstreams: upstreamPool,
}
- if changeHost {
+ // set up header_up
+ headerUp, err := fs.GetStringSlice("header-up")
+ if err != nil {
+ return caddy.ExitCodeFailedStartup, fmt.Errorf("invalid header flag: %v", err)
+ }
+ if len(headerUp) > 0 {
+ reqHdr := make(http.Header)
+ for i, h := range headerUp {
+ key, val, found := strings.Cut(h, ":")
+ key, val = strings.TrimSpace(key), strings.TrimSpace(val)
+ if !found || key == "" || val == "" {
+ return caddy.ExitCodeFailedStartup, fmt.Errorf("header-up %d: invalid format \"%s\" (expecting \"Field: value\")", i, h)
+ }
+ reqHdr.Set(key, val)
+ }
handler.Headers = &headers.Handler{
Request: &headers.HeaderOps{
- Set: http.Header{
- "Host": []string{"{http.reverse_proxy.upstream.hostport}"},
- },
+ Set: reqHdr,
+ },
+ }
+ }
+
+ // set up header_down
+ headerDown, err := fs.GetStringSlice("header-down")
+ if err != nil {
+ return caddy.ExitCodeFailedStartup, fmt.Errorf("invalid header flag: %v", err)
+ }
+ if len(headerDown) > 0 {
+ respHdr := make(http.Header)
+ for i, h := range headerDown {
+ key, val, found := strings.Cut(h, ":")
+ key, val = strings.TrimSpace(key), strings.TrimSpace(val)
+ if !found || key == "" || val == "" {
+ return caddy.ExitCodeFailedStartup, fmt.Errorf("header-down %d: invalid format \"%s\" (expecting \"Field: value\")", i, h)
+ }
+ respHdr.Set(key, val)
+ }
+ if handler.Headers == nil {
+ handler.Headers = &headers.Handler{}
+ }
+ handler.Headers.Response = &headers.RespHeaderOps{
+ HeaderOps: &headers.HeaderOps{
+ Set: respHdr,
},
}
}
+ if changeHost {
+ if handler.Headers == nil {
+ handler.Headers = &headers.Handler{
+ Request: &headers.HeaderOps{
+ Set: http.Header{},
+ },
+ }
+ }
+ handler.Headers.Request.Set.Set("Host", "{http.reverse_proxy.upstream.hostport}")
+ }
+
route := caddyhttp.Route{
HandlersRaw: []json.RawMessage{
caddyconfig.JSONModuleObject(handler, "handler", "reverse_proxy", nil),
@@ -173,6 +255,9 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) {
Routes: caddyhttp.RouteList{route},
Listen: []string{":" + fromAddr.Port},
}
+ if accessLog {
+ server.Logs = &caddyhttp.ServerLogConfig{}
+ }
if fromAddr.Scheme == "http" {
server.AutoHTTPS = &caddyhttp.AutoHTTPSConfig{Disabled: true}
@@ -191,8 +276,8 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) {
tlsApp := caddytls.TLS{
Automation: &caddytls.AutomationConfig{
Policies: []*caddytls.AutomationPolicy{{
- Subjects: []string{fromAddr.Host},
- IssuersRaw: []json.RawMessage{json.RawMessage(`{"module":"internal"}`)},
+ SubjectsRaw: []string{fromAddr.Host},
+ IssuersRaw: []json.RawMessage{json.RawMessage(`{"module":"internal"}`)},
}},
},
}
@@ -201,7 +286,8 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) {
var false bool
cfg := &caddy.Config{
- Admin: &caddy.AdminConfig{Disabled: true,
+ Admin: &caddy.AdminConfig{
+ Disabled: true,
Config: &caddy.ConfigSettings{
Persist: &false,
},
@@ -212,7 +298,7 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) {
if debug {
cfg.Logging = &caddy.Logging{
Logs: map[string]*caddy.CustomLog{
- "default": {Level: zap.DebugLevel.CapitalString()},
+ "default": {BaseLog: caddy.BaseLog{Level: zap.DebugLevel.CapitalString()}},
},
}
}
@@ -231,6 +317,3 @@ func cmdReverseProxy(fs caddycmd.Flags) (int, error) {
select {}
}
-
-// reverseProxyCmdTo holds the parsed values from repeated use of the --to flag.
-var reverseProxyCmdTo caddycmd.StringSlice
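
With the cobra migration, --to becomes a repeatable string-slice flag and port ranges expand just as in the Caddyfile; an illustrative invocation using the flags registered above:

	caddy reverse-proxy --from example.com \
		--to localhost:9001-9002 \
		--change-host-header \
		--header-down "X-Proxied-By: caddy" \
		--access-log
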
diff --git a/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go b/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go
index 799050e..a24a3ed 100644
--- a/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go
+++ b/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go
@@ -217,25 +217,18 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error
return nil, dispenser.ArgErr()
}
fcgiTransport.Root = dispenser.Val()
- dispenser.Delete()
- dispenser.Delete()
+ dispenser.DeleteN(2)
case "split":
extensions = dispenser.RemainingArgs()
- dispenser.Delete()
- for range extensions {
- dispenser.Delete()
- }
+ dispenser.DeleteN(len(extensions) + 1)
if len(extensions) == 0 {
return nil, dispenser.ArgErr()
}
case "env":
args := dispenser.RemainingArgs()
- dispenser.Delete()
- for range args {
- dispenser.Delete()
- }
+ dispenser.DeleteN(len(args) + 1)
if len(args) != 2 {
return nil, dispenser.ArgErr()
}
@@ -246,10 +239,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error
case "index":
args := dispenser.RemainingArgs()
- dispenser.Delete()
- for range args {
- dispenser.Delete()
- }
+ dispenser.DeleteN(len(args) + 1)
if len(args) != 1 {
return nil, dispenser.ArgErr()
}
@@ -257,10 +247,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error
case "try_files":
args := dispenser.RemainingArgs()
- dispenser.Delete()
- for range args {
- dispenser.Delete()
- }
+ dispenser.DeleteN(len(args) + 1)
if len(args) < 1 {
return nil, dispenser.ArgErr()
}
@@ -268,10 +255,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error
case "resolve_root_symlink":
args := dispenser.RemainingArgs()
- dispenser.Delete()
- for range args {
- dispenser.Delete()
- }
+ dispenser.DeleteN(len(args) + 1)
fcgiTransport.ResolveRootSymlink = true
case "dial_timeout":
@@ -283,8 +267,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error
return nil, dispenser.Errf("bad timeout value %s: %v", dispenser.Val(), err)
}
fcgiTransport.DialTimeout = caddy.Duration(dur)
- dispenser.Delete()
- dispenser.Delete()
+ dispenser.DeleteN(2)
case "read_timeout":
if !dispenser.NextArg() {
@@ -295,8 +278,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error
return nil, dispenser.Errf("bad timeout value %s: %v", dispenser.Val(), err)
}
fcgiTransport.ReadTimeout = caddy.Duration(dur)
- dispenser.Delete()
- dispenser.Delete()
+ dispenser.DeleteN(2)
case "write_timeout":
if !dispenser.NextArg() {
@@ -307,15 +289,11 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error
return nil, dispenser.Errf("bad timeout value %s: %v", dispenser.Val(), err)
}
fcgiTransport.WriteTimeout = caddy.Duration(dur)
- dispenser.Delete()
- dispenser.Delete()
+ dispenser.DeleteN(2)
case "capture_stderr":
args := dispenser.RemainingArgs()
- dispenser.Delete()
- for range args {
- dispenser.Delete()
- }
+ dispenser.DeleteN(len(args) + 1)
fcgiTransport.CaptureStderr = true
}
}
@@ -395,6 +373,7 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error
// the rest of the config is specified by the user
// using the reverse_proxy directive syntax
+ dispenser.Next() // consume the directive name
err = rpHandler.UnmarshalCaddyfile(dispenser)
if err != nil {
return nil, err
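// Editor's sketch (not part of the diff): the hunks above collapse runs of
// dispenser.Delete() into one DeleteN(n) call, where n counts the directive
// name plus its arguments. The tokenDeleter below is a hypothetical
// stand-in for that behavior — not the real caddyfile.Dispenser — removing
// the n tokens ending at the cursor:
package main

import "fmt"

type tokenDeleter struct {
	tokens []string
	cursor int
}

// DeleteN removes the n tokens ending at (and including) the cursor.
func (t *tokenDeleter) DeleteN(n int) {
	start := t.cursor - n + 1
	t.tokens = append(t.tokens[:start], t.tokens[t.cursor+1:]...)
	t.cursor = start - 1
}

func main() {
	d := &tokenDeleter{tokens: []string{"php_fastcgi", "env", "FOO", "bar"}, cursor: 3}
	d.DeleteN(3)          // drop "env FOO bar" in one call instead of three Delete()s
	fmt.Println(d.tokens) // [php_fastcgi]
}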
diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client.go b/modules/caddyhttp/reverseproxy/fastcgi/client.go
index ae36dd8..04513dd 100644
--- a/modules/caddyhttp/reverseproxy/fastcgi/client.go
+++ b/modules/caddyhttp/reverseproxy/fastcgi/client.go
@@ -251,7 +251,6 @@ func (c *client) Request(p map[string]string, req io.Reader) (resp *http.Respons
// Get issues a GET request to the fcgi responder.
func (c *client) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) {
-
p["REQUEST_METHOD"] = "GET"
p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10)
@@ -260,7 +259,6 @@ func (c *client) Get(p map[string]string, body io.Reader, l int64) (resp *http.R
// Head issues a HEAD request to the fcgi responder.
func (c *client) Head(p map[string]string) (resp *http.Response, err error) {
-
p["REQUEST_METHOD"] = "HEAD"
p["CONTENT_LENGTH"] = "0"
@@ -269,7 +267,6 @@ func (c *client) Head(p map[string]string) (resp *http.Response, err error) {
// Options issues an OPTIONS request to the fcgi responder.
func (c *client) Options(p map[string]string) (resp *http.Response, err error) {
-
p["REQUEST_METHOD"] = "OPTIONS"
p["CONTENT_LENGTH"] = "0"
diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go
index ec194e7..31febdd 100644
--- a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go
+++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go
@@ -24,13 +24,13 @@ import (
"strings"
"time"
- "github.com/caddyserver/caddy/v2/modules/caddyhttp"
- "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy"
- "github.com/caddyserver/caddy/v2/modules/caddytls"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/caddyserver/caddy/v2"
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp"
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy"
+ "github.com/caddyserver/caddy/v2/modules/caddytls"
)
var noopLogger = zap.NewNop()
@@ -171,6 +171,7 @@ func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
rwc: conn,
reqID: 1,
logger: logger,
+ stderr: t.CaptureStderr,
}
// read/write timeouts
@@ -254,9 +255,7 @@ func (t Transport) buildEnv(r *http.Request) (envVars, error) {
// if we didn't get a split result here.
// See https://github.com/caddyserver/caddy/issues/3718
if pathInfo == "" {
- if remainder, ok := repl.GetString("http.matchers.file.remainder"); ok {
- pathInfo = remainder
- }
+ pathInfo, _ = repl.GetString("http.matchers.file.remainder")
}
// SCRIPT_FILENAME is the absolute path of SCRIPT_NAME
@@ -286,10 +285,7 @@ func (t Transport) buildEnv(r *http.Request) (envVars, error) {
reqHost = r.Host
}
- authUser := ""
- if val, ok := repl.Get("http.auth.user.id"); ok {
- authUser = val.(string)
- }
+ authUser, _ := repl.GetString("http.auth.user.id")
// Some variables are unused but cleared explicitly to prevent
// the parent environment from interfering.
diff --git a/modules/caddyhttp/reverseproxy/forwardauth/caddyfile.go b/modules/caddyhttp/reverseproxy/forwardauth/caddyfile.go
index cecc000..8350096 100644
--- a/modules/caddyhttp/reverseproxy/forwardauth/caddyfile.go
+++ b/modules/caddyhttp/reverseproxy/forwardauth/caddyfile.go
@@ -129,8 +129,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error)
return nil, dispenser.ArgErr()
}
rpHandler.Rewrite.URI = dispenser.Val()
- dispenser.Delete()
- dispenser.Delete()
+ dispenser.DeleteN(2)
case "copy_headers":
args := dispenser.RemainingArgs()
@@ -140,13 +139,11 @@ func parseCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error)
args = append(args, dispenser.Val())
}
- dispenser.Delete() // directive name
+ // directive name + args
+ dispenser.DeleteN(len(args) + 1)
if hadBlock {
- dispenser.Delete() // opening brace
- dispenser.Delete() // closing brace
- }
- for range args {
- dispenser.Delete()
+ // opening & closing brace
+ dispenser.DeleteN(2)
}
for _, headerField := range args {
@@ -219,6 +216,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error)
// the rest of the config is specified by the user
// using the reverse_proxy directive syntax
+ dispenser.Next() // consume the directive name
err = rpHandler.UnmarshalCaddyfile(dispenser)
if err != nil {
return nil, err
diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go
index c27b24f..ad21ccb 100644
--- a/modules/caddyhttp/reverseproxy/healthchecks.go
+++ b/modules/caddyhttp/reverseproxy/healthchecks.go
@@ -24,12 +24,12 @@ import (
"regexp"
"runtime/debug"
"strconv"
- "strings"
"time"
+ "go.uber.org/zap"
+
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
- "go.uber.org/zap"
)
// HealthChecks configures active and passive health checks.
@@ -106,6 +106,76 @@ type ActiveHealthChecks struct {
logger *zap.Logger
}
+// Provision ensures that a is set up properly before use.
+func (a *ActiveHealthChecks) Provision(ctx caddy.Context, h *Handler) error {
+ if !a.IsEnabled() {
+ return nil
+ }
+
+	// Canonicalize the header keys ahead of time, since
+	// JSON-unmarshaled headers may not be in canonical form
+ cleaned := http.Header{}
+ for key, hdrs := range a.Headers {
+ for _, val := range hdrs {
+ cleaned.Add(key, val)
+ }
+ }
+ a.Headers = cleaned
+
+ h.HealthChecks.Active.logger = h.logger.Named("health_checker.active")
+
+ timeout := time.Duration(a.Timeout)
+ if timeout == 0 {
+ timeout = 5 * time.Second
+ }
+
+ if a.Path != "" {
+ a.logger.Warn("the 'path' option is deprecated, please use 'uri' instead!")
+ }
+
+ // parse the URI string (supports path and query)
+ if a.URI != "" {
+ parsedURI, err := url.Parse(a.URI)
+ if err != nil {
+ return err
+ }
+ a.uri = parsedURI
+ }
+
+ a.httpClient = &http.Client{
+ Timeout: timeout,
+ Transport: h.Transport,
+ }
+
+ for _, upstream := range h.Upstreams {
+ // if there's an alternative port for health-check provided in the config,
+ // then use it, otherwise use the port of upstream.
+ if a.Port != 0 {
+ upstream.activeHealthCheckPort = a.Port
+ }
+ }
+
+ if a.Interval == 0 {
+ a.Interval = caddy.Duration(30 * time.Second)
+ }
+
+ if a.ExpectBody != "" {
+ var err error
+ a.bodyRegexp, err = regexp.Compile(a.ExpectBody)
+ if err != nil {
+ return fmt.Errorf("expect_body: compiling regular expression: %v", err)
+ }
+ }
+
+ return nil
+}
+
+// IsEnabled checks if the active health checks have
+// the minimum config necessary to be enabled.
+func (a *ActiveHealthChecks) IsEnabled() bool {
+ return a.Path != "" || a.URI != "" || a.Port != 0
+}
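// Editor's sketch (not part of the diff): why Provision re-Adds every
// header above — http.Header.Add canonicalizes field names, while a
// JSON-unmarshaled map keeps whatever casing the config used, so lookups
// like Headers.Get would miss. Stdlib-only illustration:
package main

import (
	"fmt"
	"net/http"
)

func main() {
	raw := http.Header{"x-health-host": {"example.com"}} // non-canonical key, as JSON may produce
	fmt.Println(raw.Get("X-Health-Host"))                // "" — canonical lookup misses

	cleaned := http.Header{}
	for key, vals := range raw {
		for _, val := range vals {
			cleaned.Add(key, val) // Add stores under the canonical "X-Health-Host"
		}
	}
	fmt.Println(cleaned.Get("X-Health-Host")) // "example.com"
}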
+
// PassiveHealthChecks holds configuration related to passive
// health checks (that is, health checks which occur during
// the normal flow of request proxying).
@@ -203,7 +273,7 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
}
addr.StartPort, addr.EndPort = hcp, hcp
}
- if upstream.LookupSRV == "" && addr.PortRangeSize() != 1 {
+ if addr.PortRangeSize() != 1 {
h.HealthChecks.Active.logger.Error("multiple addresses (upstream must map to only one address)",
zap.String("address", networkAddr),
)
@@ -237,16 +307,35 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
// the host's health status fails.
func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstream *Upstream) error {
// create the URL for the request that acts as a health check
- scheme := "http"
- if ht, ok := h.Transport.(TLSTransport); ok && ht.TLSEnabled() {
- // this is kind of a hacky way to know if we should use HTTPS, but whatever
- scheme = "https"
- }
u := &url.URL{
- Scheme: scheme,
+ Scheme: "http",
Host: hostAddr,
}
+ // split the host and port if possible, override the port if configured
+ host, port, err := net.SplitHostPort(hostAddr)
+ if err != nil {
+ host = hostAddr
+ }
+ if h.HealthChecks.Active.Port != 0 {
+ port := strconv.Itoa(h.HealthChecks.Active.Port)
+ u.Host = net.JoinHostPort(host, port)
+ }
+
+ // this is kind of a hacky way to know if we should use HTTPS, but whatever
+ if tt, ok := h.Transport.(TLSTransport); ok && tt.TLSEnabled() {
+ u.Scheme = "https"
+
+ // if the port is in the except list, flip back to HTTP
+ if ht, ok := h.Transport.(*HTTPTransport); ok {
+ for _, exceptPort := range ht.TLS.ExceptPorts {
+ if exceptPort == port {
+ u.Scheme = "http"
+ }
+ }
+ }
+ }
+
// if we have a provisioned uri, use that, otherwise use
// the deprecated Path option
if h.HealthChecks.Active.uri != nil {
@@ -256,16 +345,6 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre
u.Path = h.HealthChecks.Active.Path
}
- // adjust the port, if configured to be different
- if h.HealthChecks.Active.Port != 0 {
- portStr := strconv.Itoa(h.HealthChecks.Active.Port)
- host, _, err := net.SplitHostPort(hostAddr)
- if err != nil {
- host = hostAddr
- }
- u.Host = net.JoinHostPort(host, portStr)
- }
-
// attach dialing information to this request, as well as context values that
// may be expected by handlers of this request
ctx := h.ctx.Context
@@ -279,11 +358,17 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstre
}
ctx = context.WithValue(ctx, caddyhttp.OriginalRequestCtxKey, *req)
req = req.WithContext(ctx)
- for key, hdrs := range h.HealthChecks.Active.Headers {
- if strings.ToLower(key) == "host" {
- req.Host = h.HealthChecks.Active.Headers.Get(key)
- } else {
- req.Header[key] = hdrs
+
+ // set headers, using a replacer with only globals (env vars, system info, etc.)
+ repl := caddy.NewReplacer()
+ for key, vals := range h.HealthChecks.Active.Headers {
+ key = repl.ReplaceAll(key, "")
+ if key == "Host" {
+ req.Host = repl.ReplaceAll(h.HealthChecks.Active.Headers.Get(key), "")
+ continue
+ }
+ for _, val := range vals {
+ req.Header.Add(key, repl.ReplaceKnown(val, ""))
}
}
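// Editor's sketch (not part of the diff): the URL construction above in
// isolation — split host/port, apply the configured health-check port if
// any, default to "http" (the real code then flips to "https" when the
// transport has TLS enabled, except for ports on the TLS except list).
// Hypothetical inputs, stdlib only:
package main

import (
	"fmt"
	"net"
	"net/url"
	"strconv"
)

func healthCheckURL(hostAddr string, overridePort int, path string) *url.URL {
	u := &url.URL{Scheme: "http", Host: hostAddr, Path: path}
	host, _, err := net.SplitHostPort(hostAddr)
	if err != nil {
		host = hostAddr // address had no port
	}
	if overridePort != 0 {
		u.Host = net.JoinHostPort(host, strconv.Itoa(overridePort))
	}
	return u
}

func main() {
	fmt.Println(healthCheckURL("10.0.0.1:8080", 9090, "/healthz")) // http://10.0.0.1:9090/healthz
}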
diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go
index a973ecb..83a39d8 100644
--- a/modules/caddyhttp/reverseproxy/hosts.go
+++ b/modules/caddyhttp/reverseproxy/hosts.go
@@ -17,8 +17,8 @@ package reverseproxy
import (
"context"
"fmt"
- "net"
"net/http"
+ "net/netip"
"strconv"
"sync/atomic"
@@ -47,15 +47,6 @@ type Upstream struct {
// backends is down. Also be aware of open proxy vulnerabilities.
Dial string `json:"dial,omitempty"`
- // DEPRECATED: Use the SRVUpstreams module instead
- // (http.reverse_proxy.upstreams.srv). This field will be
- // removed in a future version of Caddy. TODO: Remove this field.
- //
- // If DNS SRV records are used for service discovery with this
- // upstream, specify the DNS name for which to look up SRV
- // records here, instead of specifying a dial address.
- LookupSRV string `json:"lookup_srv,omitempty"`
-
// The maximum number of simultaneous requests to allow to
// this upstream. If set, overrides the global passive health
// check UnhealthyRequestCount value.
@@ -72,12 +63,10 @@ type Upstream struct {
unhealthy int32 // accessed atomically; status from active health checker
}
-func (u Upstream) String() string {
- if u.LookupSRV != "" {
- return u.LookupSRV
- }
- return u.Dial
-}
+// (pointer receiver necessary to avoid a race condition, since
+// copying the Upstream reads the 'unhealthy' field which is
+// accessed atomically)
+func (u *Upstream) String() string { return u.Dial }
// Available returns true if the remote host
// is available to receive requests. This is
@@ -109,35 +98,21 @@ func (u *Upstream) Full() bool {
}
// fillDialInfo returns a filled DialInfo for upstream u, using the request
-// context. If the upstream has a SRV lookup configured, that is done and a
-// returned address is chosen; otherwise, the upstream's regular dial address
-// field is used. Note that the returned value is not a pointer.
+// context. Note that the returned value is not a pointer.
func (u *Upstream) fillDialInfo(r *http.Request) (DialInfo, error) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
var addr caddy.NetworkAddress
- if u.LookupSRV != "" {
- // perform DNS lookup for SRV records and choose one - TODO: deprecated
- srvName := repl.ReplaceAll(u.LookupSRV, "")
- _, records, err := net.DefaultResolver.LookupSRV(r.Context(), "", "", srvName)
- if err != nil {
- return DialInfo{}, err
- }
- addr.Network = "tcp"
- addr.Host = records[0].Target
- addr.StartPort, addr.EndPort = uint(records[0].Port), uint(records[0].Port)
- } else {
- // use provided dial address
- var err error
- dial := repl.ReplaceAll(u.Dial, "")
- addr, err = caddy.ParseNetworkAddress(dial)
- if err != nil {
- return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", u.Dial, dial, err)
- }
- if numPorts := addr.PortRangeSize(); numPorts != 1 {
- return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d",
- u.Dial, dial, numPorts)
- }
+ // use provided dial address
+ var err error
+ dial := repl.ReplaceAll(u.Dial, "")
+ addr, err = caddy.ParseNetworkAddress(dial)
+ if err != nil {
+ return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", u.Dial, dial, err)
+ }
+ if numPorts := addr.PortRangeSize(); numPorts != 1 {
+ return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d",
+ u.Dial, dial, numPorts)
}
return DialInfo{
@@ -259,3 +234,13 @@ var hosts = caddy.NewUsagePool()
// dialInfoVarKey is the key used for the variable that holds
// the dial info for the upstream connection.
const dialInfoVarKey = "reverse_proxy.dial_info"
+
+// proxyProtocolInfoVarKey is the key used for the variable that holds
+// the proxy protocol info for the upstream connection.
+const proxyProtocolInfoVarKey = "reverse_proxy.proxy_protocol_info"
+
+// ProxyProtocolInfo contains information needed to write proxy protocol to a
+// connection to an upstream host.
+type ProxyProtocolInfo struct {
+ AddrPort netip.AddrPort
+}
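// Editor's sketch (not part of the diff): fillDialInfo above rejects dial
// addresses that expand to more than one socket. The same check, reusing
// caddy.ParseNetworkAddress as the hunk does (addresses are made up):
package main

import (
	"fmt"

	"github.com/caddyserver/caddy/v2"
)

func main() {
	for _, dial := range []string{"localhost:8080", "localhost:8080-8085"} {
		addr, err := caddy.ParseNetworkAddress(dial)
		if err != nil {
			fmt.Printf("%s: invalid dial address: %v\n", dial, err)
			continue
		}
		if n := addr.PortRangeSize(); n != 1 {
			fmt.Printf("%s: must represent precisely one socket, not %d\n", dial, n)
			continue
		}
		fmt.Printf("%s: ok\n", dial)
	}
}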
diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go
index ec5d2f2..9290f7e 100644
--- a/modules/caddyhttp/reverseproxy/httptransport.go
+++ b/modules/caddyhttp/reverseproxy/httptransport.go
@@ -28,10 +28,13 @@ import (
"strings"
"time"
- "github.com/caddyserver/caddy/v2"
- "github.com/caddyserver/caddy/v2/modules/caddytls"
+ "github.com/mastercactapus/proxyprotocol"
"go.uber.org/zap"
"golang.org/x/net/http2"
+
+ "github.com/caddyserver/caddy/v2"
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp"
+ "github.com/caddyserver/caddy/v2/modules/caddytls"
)
func init() {
@@ -64,6 +67,10 @@ type HTTPTransport struct {
// Maximum number of connections per host. Default: 0 (no limit)
MaxConnsPerHost int `json:"max_conns_per_host,omitempty"`
+ // If non-empty, which PROXY protocol version to send when
+ // connecting to an upstream. Default: off.
+ ProxyProtocol string `json:"proxy_protocol,omitempty"`
+
// How long to wait before timing out trying to connect to
// an upstream. Default: `3s`.
DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
@@ -172,12 +179,19 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
}
}
- // Set up the dialer to pull the correct information from the context
dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
- // the proper dialing information should be embedded into the request's context
+ // For unix socket upstreams, we need to recover the dial info from
+ // the request's context, because the Host on the request's URL
+ // will have been modified by directing the request, overwriting
+ // the unix socket filename.
+ // Also, we need to avoid overwriting the address at this point
+ // when not necessary, because http.ProxyFromEnvironment may have
+ // modified the address according to the user's env proxy config.
if dialInfo, ok := GetDialInfo(ctx); ok {
- network = dialInfo.Network
- address = dialInfo.Address
+ if strings.HasPrefix(dialInfo.Network, "unix") {
+ network = dialInfo.Network
+ address = dialInfo.Address
+ }
}
conn, err := dialer.DialContext(ctx, network, address)
@@ -188,8 +202,59 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
return nil, DialError{err}
}
- // if read/write timeouts are configured and this is a TCP connection, enforce the timeouts
- // by wrapping the connection with our own type
+ if h.ProxyProtocol != "" {
+ proxyProtocolInfo, ok := caddyhttp.GetVar(ctx, proxyProtocolInfoVarKey).(ProxyProtocolInfo)
+ if !ok {
+ return nil, fmt.Errorf("failed to get proxy protocol info from context")
+ }
+
+			// The src and dst have to be of the same address family. As we don't know the original
+			// dst address (it's kind of impossible to know) and this address is generally of very
+			// little interest, we just set it to all zeros.
+ var destIP net.IP
+ switch {
+ case proxyProtocolInfo.AddrPort.Addr().Is4():
+ destIP = net.IPv4zero
+ case proxyProtocolInfo.AddrPort.Addr().Is6():
+ destIP = net.IPv6zero
+ default:
+ return nil, fmt.Errorf("unexpected remote addr type in proxy protocol info")
+ }
+
+ // TODO: We should probably migrate away from net.IP to use netip.Addr,
+ // but due to the upstream dependency, we can't do that yet.
+ switch h.ProxyProtocol {
+ case "v1":
+ header := proxyprotocol.HeaderV1{
+ SrcIP: net.IP(proxyProtocolInfo.AddrPort.Addr().AsSlice()),
+ SrcPort: int(proxyProtocolInfo.AddrPort.Port()),
+ DestIP: destIP,
+ DestPort: 0,
+ }
+ caddyCtx.Logger().Debug("sending proxy protocol header v1", zap.Any("header", header))
+ _, err = header.WriteTo(conn)
+ case "v2":
+ header := proxyprotocol.HeaderV2{
+ Command: proxyprotocol.CmdProxy,
+ Src: &net.TCPAddr{IP: net.IP(proxyProtocolInfo.AddrPort.Addr().AsSlice()), Port: int(proxyProtocolInfo.AddrPort.Port())},
+ Dest: &net.TCPAddr{IP: destIP, Port: 0},
+ }
+ caddyCtx.Logger().Debug("sending proxy protocol header v2", zap.Any("header", header))
+ _, err = header.WriteTo(conn)
+ default:
+ return nil, fmt.Errorf("unexpected proxy protocol version")
+ }
+
+ if err != nil {
+ // identify this error as one that occurred during
+ // dialing, which can be important when trying to
+ // decide whether to retry a request
+ return nil, DialError{err}
+ }
+ }
+
+ // if read/write timeouts are configured and this is a TCP connection,
+ // enforce the timeouts by wrapping the connection with our own type
if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) {
conn = &tcpRWTimeoutConn{
TCPConn: tcpConn,
@@ -203,6 +268,7 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
}
rt := &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
DialContext: dialContext,
MaxConnsPerHost: h.MaxConnsPerHost,
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
@@ -231,6 +297,14 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout)
}
+	// The PROXY protocol header can only be sent once, right after opening the connection.
+	// So a single connection must not be reused for multiple requests, which could
+	// potentially come from different clients.
+ if !rt.DisableKeepAlives && h.ProxyProtocol != "" {
+ caddyCtx.Logger().Warn("disabling keepalives, they are incompatible with using PROXY protocol")
+ rt.DisableKeepAlives = true
+ }
+
if h.Compression != nil {
rt.DisableCompression = !*h.Compression
}
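// Editor's sketch (not part of the diff): the PROXY protocol support
// above, reduced to its essence — a v1 header is written exactly once,
// immediately after dialing and before any HTTP bytes, which is also why
// the hunk disables keepalives. Uses the same library the diff imports;
// the addresses are made up:
package main

import (
	"net"

	"github.com/mastercactapus/proxyprotocol"
)

func writeProxyHeaderV1(conn net.Conn, srcIP net.IP, srcPort int) error {
	header := proxyprotocol.HeaderV1{
		SrcIP:    srcIP,
		SrcPort:  srcPort,
		DestIP:   net.IPv4zero, // original destination is unknown, so all zeros
		DestPort: 0,
	}
	_, err := header.WriteTo(conn) // must happen before the first request
	return err
}

func main() {
	conn, err := net.Dial("tcp", "127.0.0.1:9000") // hypothetical upstream
	if err != nil {
		return
	}
	defer conn.Close()
	_ = writeProxyHeaderV1(conn, net.ParseIP("203.0.113.7"), 54321)
}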
@@ -452,10 +526,11 @@ func (t TLSConfig) MakeTLSClientConfig(ctx caddy.Context) (*tls.Config, error) {
return nil, fmt.Errorf("managing client certificate: %v", err)
}
cfg.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
- certs := tlsApp.AllMatchingCertificates(t.ClientCertificateAutomate)
+ certs := caddytls.AllMatchingCertificates(t.ClientCertificateAutomate)
var err error
for _, cert := range certs {
- err = cri.SupportsCertificate(&cert.Certificate)
+ certCertificate := cert.Certificate // avoid taking address of iteration variable (gosec warning)
+ err = cri.SupportsCertificate(&certCertificate)
if err == nil {
return &cert.Certificate, nil
}
diff --git a/modules/caddyhttp/reverseproxy/metrics.go b/modules/caddyhttp/reverseproxy/metrics.go
index 4272bc4..d3c8ee0 100644
--- a/modules/caddyhttp/reverseproxy/metrics.go
+++ b/modules/caddyhttp/reverseproxy/metrics.go
@@ -39,6 +39,8 @@ func newMetricsUpstreamsHealthyUpdater(handler *Handler) *metricsUpstreamsHealth
initReverseProxyMetrics(handler)
})
+ reverseProxyMetrics.upstreamsHealthy.Reset()
+
return &metricsUpstreamsHealthyUpdater{handler}
}
diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go
index 1449785..08be40d 100644
--- a/modules/caddyhttp/reverseproxy/reverseproxy.go
+++ b/modules/caddyhttp/reverseproxy/reverseproxy.go
@@ -27,30 +27,23 @@ import (
"net/netip"
"net/textproto"
"net/url"
- "regexp"
- "runtime"
"strconv"
"strings"
"sync"
"time"
+ "go.uber.org/zap"
+ "golang.org/x/net/http/httpguts"
+
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"github.com/caddyserver/caddy/v2/modules/caddyevents"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/headers"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite"
- "go.uber.org/zap"
- "golang.org/x/net/http/httpguts"
)
-var supports1xx bool
-
func init() {
- // Caddy requires at least Go 1.18, but Early Hints requires Go 1.19; thus we can simply check for 1.18 in version string
- // TODO: remove this once our minimum Go version is 1.19
- supports1xx = !strings.Contains(runtime.Version(), "go1.18")
-
caddy.RegisterModule(Handler{})
}
@@ -158,6 +151,19 @@ type Handler struct {
// could be useful if the backend has tighter memory constraints.
ResponseBuffers int64 `json:"response_buffers,omitempty"`
+ // If nonzero, streaming requests such as WebSockets will be
+ // forcibly closed at the end of the timeout. Default: no timeout.
+ StreamTimeout caddy.Duration `json:"stream_timeout,omitempty"`
+
+ // If nonzero, streaming requests such as WebSockets will not be
+ // closed when the proxy config is unloaded, and instead the stream
+ // will remain open until the delay is complete. In other words,
+ // enabling this prevents streams from closing when Caddy's config
+	// is reloaded. Enabling this can be useful to avoid a thundering
+	// herd of reconnecting clients whose connections were closed
+	// when the previous config shut down. Default: no delay.
+ StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"`
+
// If configured, rewrites the copy of the upstream request.
// Allows changing the request method and URI (path and query).
// Since the rewrite is applied to the copy, it does not persist
@@ -185,6 +191,13 @@ type Handler struct {
// - `{http.reverse_proxy.header.*}` The headers from the response
HandleResponse []caddyhttp.ResponseHandler `json:"handle_response,omitempty"`
+ // If set, the proxy will write very detailed logs about its
+ // inner workings. Enable this only when debugging, as it
+ // will produce a lot of output.
+ //
+ // EXPERIMENTAL: This feature is subject to change or removal.
+ VerboseLogs bool `json:"verbose_logs,omitempty"`
+
Transport http.RoundTripper `json:"-"`
CB CircuitBreaker `json:"-"`
DynamicUpstreams UpstreamSource `json:"-"`
@@ -199,8 +212,9 @@ type Handler struct {
handleResponseSegments []*caddyfile.Dispenser
// Stores upgraded requests (hijacked connections) for proper cleanup
- connections map[io.ReadWriteCloser]openConnection
- connectionsMu *sync.Mutex
+ connections map[io.ReadWriteCloser]openConnection
+ connectionsCloseTimer *time.Timer
+ connectionsMu *sync.Mutex
ctx caddy.Context
logger *zap.Logger
@@ -243,20 +257,6 @@ func (h *Handler) Provision(ctx caddy.Context) error {
h.logger.Warn("UNLIMITED BUFFERING: buffering is enabled without any cap on buffer size, which can result in OOM crashes")
}
- // verify SRV compatibility - TODO: LookupSRV deprecated; will be removed
- for i, v := range h.Upstreams {
- if v.LookupSRV == "" {
- continue
- }
- h.logger.Warn("DEPRECATED: lookup_srv: will be removed in a near-future version of Caddy; use the http.reverse_proxy.upstreams.srv module instead")
- if h.HealthChecks != nil && h.HealthChecks.Active != nil {
- return fmt.Errorf(`upstream: lookup_srv is incompatible with active health checks: %d: {"dial": %q, "lookup_srv": %q}`, i, v.Dial, v.LookupSRV)
- }
- if v.Dial != "" {
- return fmt.Errorf(`upstream: specifying dial address is incompatible with lookup_srv: %d: {"dial": %q, "lookup_srv": %q}`, i, v.Dial, v.LookupSRV)
- }
- }
-
// start by loading modules
if h.TransportRaw != nil {
mod, err := ctx.LoadModule(h, "TransportRaw")
@@ -363,62 +363,22 @@ func (h *Handler) Provision(ctx caddy.Context) error {
if h.HealthChecks != nil {
// set defaults on passive health checks, if necessary
if h.HealthChecks.Passive != nil {
- if h.HealthChecks.Passive.FailDuration > 0 && h.HealthChecks.Passive.MaxFails == 0 {
+ h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive")
+ if h.HealthChecks.Passive.MaxFails == 0 {
h.HealthChecks.Passive.MaxFails = 1
}
}
// if active health checks are enabled, configure them and start a worker
- if h.HealthChecks.Active != nil && (h.HealthChecks.Active.Path != "" ||
- h.HealthChecks.Active.URI != "" ||
- h.HealthChecks.Active.Port != 0) {
-
- h.HealthChecks.Active.logger = h.logger.Named("health_checker.active")
-
- timeout := time.Duration(h.HealthChecks.Active.Timeout)
- if timeout == 0 {
- timeout = 5 * time.Second
- }
-
- if h.HealthChecks.Active.Path != "" {
- h.HealthChecks.Active.logger.Warn("the 'path' option is deprecated, please use 'uri' instead!")
- }
-
- // parse the URI string (supports path and query)
- if h.HealthChecks.Active.URI != "" {
- parsedURI, err := url.Parse(h.HealthChecks.Active.URI)
- if err != nil {
- return err
- }
- h.HealthChecks.Active.uri = parsedURI
- }
-
- h.HealthChecks.Active.httpClient = &http.Client{
- Timeout: timeout,
- Transport: h.Transport,
- }
-
- for _, upstream := range h.Upstreams {
- // if there's an alternative port for health-check provided in the config,
- // then use it, otherwise use the port of upstream.
- if h.HealthChecks.Active.Port != 0 {
- upstream.activeHealthCheckPort = h.HealthChecks.Active.Port
- }
+ if h.HealthChecks.Active != nil {
+ err := h.HealthChecks.Active.Provision(ctx, h)
+ if err != nil {
+ return err
}
- if h.HealthChecks.Active.Interval == 0 {
- h.HealthChecks.Active.Interval = caddy.Duration(30 * time.Second)
+ if h.HealthChecks.Active.IsEnabled() {
+ go h.activeHealthChecker()
}
-
- if h.HealthChecks.Active.ExpectBody != "" {
- var err error
- h.HealthChecks.Active.bodyRegexp, err = regexp.Compile(h.HealthChecks.Active.ExpectBody)
- if err != nil {
- return fmt.Errorf("expect_body: compiling regular expression: %v", err)
- }
- }
-
- go h.activeHealthChecker()
}
}
@@ -438,25 +398,7 @@ func (h *Handler) Provision(ctx caddy.Context) error {
// Cleanup cleans up the resources made by h.
func (h *Handler) Cleanup() error {
- // close hijacked connections (both to client and backend)
- var err error
- h.connectionsMu.Lock()
- for _, oc := range h.connections {
- if oc.gracefulClose != nil {
- // this is potentially blocking while we have the lock on the connections
- // map, but that should be OK since the server has in theory shut down
- // and we are no longer using the connections map
- gracefulErr := oc.gracefulClose()
- if gracefulErr != nil && err == nil {
- err = gracefulErr
- }
- }
- closeErr := oc.conn.Close()
- if closeErr != nil && err == nil {
- err = closeErr
- }
- }
- h.connectionsMu.Unlock()
+ err := h.cleanupConnections()
// remove hosts from our config from the pool
for _, upstream := range h.Upstreams {
@@ -517,7 +459,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
// It returns true when the loop is done and should break; false otherwise. The error value returned should
// be assigned to the proxyErr value for the next iteration of the loop (or the error handled after break).
func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w http.ResponseWriter, proxyErr error, start time.Time, retries int,
- repl *caddy.Replacer, reqHeader http.Header, reqHost string, next caddyhttp.Handler) (bool, error) {
+ repl *caddy.Replacer, reqHeader http.Header, reqHost string, next caddyhttp.Handler,
+) (bool, error) {
// get the updated list of upstreams
upstreams := h.Upstreams
if h.DynamicUpstreams != nil {
@@ -544,7 +487,7 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h
upstream := h.LoadBalancing.SelectionPolicy.Select(upstreams, r, w)
if upstream == nil {
if proxyErr == nil {
- proxyErr = caddyhttp.Error(http.StatusServiceUnavailable, fmt.Errorf("no upstreams available"))
+ proxyErr = caddyhttp.Error(http.StatusServiceUnavailable, noUpstreamsAvailable)
}
if !h.LoadBalancing.tryAgain(h.ctx, start, retries, proxyErr, r) {
return true, proxyErr
@@ -646,7 +589,8 @@ func (h Handler) prepareRequest(req *http.Request, repl *caddy.Replacer) (*http.
// feature if absolutely required, if read timeouts are
// set, and if body size is limited
if h.RequestBuffers != 0 && req.Body != nil {
- req.Body, _ = h.bufferedBody(req.Body, h.RequestBuffers)
+ req.Body, req.ContentLength = h.bufferedBody(req.Body, h.RequestBuffers)
+ req.Header.Set("Content-Length", strconv.FormatInt(req.ContentLength, 10))
}
if req.ContentLength == 0 {
@@ -687,8 +631,24 @@ func (h Handler) prepareRequest(req *http.Request, repl *caddy.Replacer) (*http.
req.Header.Set("Upgrade", reqUpType)
}
+ // Set up the PROXY protocol info
+ address := caddyhttp.GetVar(req.Context(), caddyhttp.ClientIPVarKey).(string)
+ addrPort, err := netip.ParseAddrPort(address)
+ if err != nil {
+ // OK; probably didn't have a port
+ addr, err := netip.ParseAddr(address)
+ if err != nil {
+ // Doesn't seem like a valid ip address at all
+			// Doesn't seem like a valid IP address at all
+		} else {
+			// OK, only the port was missing
+ }
+ }
+ proxyProtocolInfo := ProxyProtocolInfo{AddrPort: addrPort}
+ caddyhttp.SetVar(req.Context(), proxyProtocolInfoVarKey, proxyProtocolInfo)
+
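// Editor's sketch (not part of the diff): the client-address parsing just
// above tolerates a missing port — try netip.ParseAddrPort first, then
// fall back to netip.ParseAddr with port 0. Standalone version:
package main

import (
	"fmt"
	"net/netip"
)

func parseClientAddr(address string) (netip.AddrPort, bool) {
	if addrPort, err := netip.ParseAddrPort(address); err == nil {
		return addrPort, true
	}
	addr, err := netip.ParseAddr(address)
	if err != nil {
		return netip.AddrPort{}, false // not an IP address at all
	}
	return netip.AddrPortFrom(addr, 0), true // only the port was missing
}

func main() {
	fmt.Println(parseClientAddr("203.0.113.7:1234")) // 203.0.113.7:1234 true
	fmt.Println(parseClientAddr("203.0.113.7"))      // 203.0.113.7:0 true
	fmt.Println(parseClientAddr("not-an-ip"))        // invalid AddrPort false
}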
// Add the supported X-Forwarded-* headers
- err := h.addForwardedHeaders(req)
+ err = h.addForwardedHeaders(req)
if err != nil {
return nil, err
}
@@ -795,25 +755,23 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
server := req.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server)
shouldLogCredentials := server.Logs != nil && server.Logs.ShouldLogCredentials
- if supports1xx {
- // Forward 1xx status codes, backported from https://github.com/golang/go/pull/53164
- trace := &httptrace.ClientTrace{
- Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
- h := rw.Header()
- copyHeader(h, http.Header(header))
- rw.WriteHeader(code)
-
- // Clear headers coming from the backend
- // (it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses)
- for k := range header {
- delete(h, k)
- }
+ // Forward 1xx status codes, backported from https://github.com/golang/go/pull/53164
+ trace := &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ h := rw.Header()
+ copyHeader(h, http.Header(header))
+ rw.WriteHeader(code)
+
+ // Clear headers coming from the backend
+ // (it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses)
+ for k := range header {
+ delete(h, k)
+ }
- return nil
- },
- }
- req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+ return nil
+ },
}
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
// if FlushInterval is explicitly configured to -1 (i.e. flush continuously to achieve
// low-latency streaming), don't let the transport cancel the request if the client
@@ -821,7 +779,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
// regardless, and we should expect client disconnection in low-latency streaming
// scenarios (see issue #4922)
if h.FlushInterval == -1 {
- req = req.WithContext(ignoreClientGoneContext{req.Context(), h.ctx.Done()})
+ req = req.WithContext(ignoreClientGoneContext{req.Context()})
}
// do the round-trip; emit debug log with values we know are
@@ -897,12 +855,6 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
break
}
- // otherwise, if there are any routes configured, execute those as the
- // actual response instead of what we got from the proxy backend
- if len(rh.Routes) == 0 {
- continue
- }
-
// set up the replacer so that parts of the original response can be
// used for routing decisions
for field, value := range res.Header {
@@ -911,7 +863,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
repl.Set("http.reverse_proxy.status_code", res.StatusCode)
repl.Set("http.reverse_proxy.status_text", res.Status)
- h.logger.Debug("handling response", zap.Int("handler", i))
+ logger.Debug("handling response", zap.Int("handler", i))
// we make some data available via request context to child routes
// so that they may inherit some options and functions from the
@@ -956,7 +908,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
}
// finalizeResponse prepares and copies the response.
-func (h Handler) finalizeResponse(
+func (h *Handler) finalizeResponse(
rw http.ResponseWriter,
req *http.Request,
res *http.Response,
@@ -998,15 +950,21 @@ func (h Handler) finalizeResponse(
}
rw.WriteHeader(res.StatusCode)
+ if h.VerboseLogs {
+ logger.Debug("wrote header")
+ }
- err := h.copyResponse(rw, res.Body, h.flushInterval(req, res))
- res.Body.Close() // close now, instead of defer, to populate res.Trailer
+ err := h.copyResponse(rw, res.Body, h.flushInterval(req, res), logger)
+ errClose := res.Body.Close() // close now, instead of defer, to populate res.Trailer
+ if h.VerboseLogs || errClose != nil {
+ logger.Debug("closed response body from upstream", zap.Error(errClose))
+ }
if err != nil {
// we're streaming the response and we've already written headers, so
		// there's nothing an error handler can do to recover at this point;
		// the standard lib's proxy panics here, but we'll just log
		// the error and abort the stream
- h.logger.Error("aborting with incomplete response", zap.Error(err))
+ logger.Error("aborting with incomplete response", zap.Error(err))
return nil
}
@@ -1014,9 +972,8 @@ func (h Handler) finalizeResponse(
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
- if fl, ok := rw.(http.Flusher); ok {
- fl.Flush()
- }
+ //nolint:bodyclose
+ http.NewResponseController(rw).Flush()
}
// total duration spent proxying, including writing response body
@@ -1035,6 +992,10 @@ func (h Handler) finalizeResponse(
}
}
+ if h.VerboseLogs {
+ logger.Debug("response finalized")
+ }
+
return nil
}
@@ -1066,17 +1027,23 @@ func (lb LoadBalancing) tryAgain(ctx caddy.Context, start time.Time, retries int
// should be safe to retry, since without a connection, no
// HTTP request can be transmitted; but if the error is not
// specifically a dialer error, we need to be careful
- if _, ok := proxyErr.(DialError); proxyErr != nil && !ok {
+ if proxyErr != nil {
+ _, isDialError := proxyErr.(DialError)
+ herr, isHandlerError := proxyErr.(caddyhttp.HandlerError)
+
// if the error occurred after a connection was established,
// we have to assume the upstream received the request, and
// retries need to be carefully decided, because some requests
// are not idempotent
- if lb.RetryMatch == nil && req.Method != "GET" {
- // by default, don't retry requests if they aren't GET
- return false
- }
- if !lb.RetryMatch.AnyMatch(req) {
- return false
+ if !isDialError && !(isHandlerError && errors.Is(herr, noUpstreamsAvailable)) {
+ if lb.RetryMatch == nil && req.Method != "GET" {
+ // by default, don't retry requests if they aren't GET
+ return false
+ }
+
+ if !lb.RetryMatch.AnyMatch(req) {
+ return false
+ }
}
}
@@ -1128,12 +1095,11 @@ func (h Handler) provisionUpstream(upstream *Upstream) {
// without MaxRequests), copy the value into this upstream, since the
// value in the upstream (MaxRequests) is what is used during
// availability checks
- if h.HealthChecks != nil && h.HealthChecks.Passive != nil {
- h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive")
- if h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
- upstream.MaxRequests == 0 {
- upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
- }
+ if h.HealthChecks != nil &&
+ h.HealthChecks.Passive != nil &&
+ h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
+ upstream.MaxRequests == 0 {
+ upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
}
// upstreams need independent access to the passive
@@ -1450,21 +1416,36 @@ type handleResponseContext struct {
// ignoreClientGoneContext is a special context.Context type
// intended for use when doing a RoundTrip where you don't
// want a client disconnection to cancel the request during
-// the roundtrip. Set its done field to a Done() channel
-// of a context that doesn't get canceled when the client
-// disconnects, such as caddy.Context.Done() instead.
+// the roundtrip.
+// This context clears cancellation, error, and deadline methods,
+// but still allows values to pass through from its embedded
+// context.
+//
+// TODO: This can be replaced with context.WithoutCancel once
+// the minimum required version of Go is 1.21.
type ignoreClientGoneContext struct {
context.Context
- done <-chan struct{}
}
-func (c ignoreClientGoneContext) Done() <-chan struct{} { return c.done }
+func (c ignoreClientGoneContext) Deadline() (deadline time.Time, ok bool) {
+ return
+}
+
+func (c ignoreClientGoneContext) Done() <-chan struct{} {
+ return nil
+}
+
+func (c ignoreClientGoneContext) Err() error {
+ return nil
+}
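// Editor's sketch (not part of the diff): what the wrapper type above
// preserves and drops — values flow through from the embedded context,
// cancellation and deadlines do not (the behavior context.WithoutCancel
// later provides in Go 1.21). Stdlib-only demonstration:
package main

import (
	"context"
	"fmt"
)

type ctxKey string

type ignoreCancel struct{ context.Context }

func (ignoreCancel) Done() <-chan struct{} { return nil }
func (ignoreCancel) Err() error            { return nil }

func main() {
	base := context.WithValue(context.Background(), ctxKey("k"), "v")
	ctx, cancel := context.WithCancel(base)
	cancel()

	wrapped := ignoreCancel{ctx}
	fmt.Println(ctx.Err())                  // context.Canceled
	fmt.Println(wrapped.Err())              // <nil> — cancellation hidden
	fmt.Println(wrapped.Value(ctxKey("k"))) // v — values still pass through
}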
// proxyHandleResponseContextCtxKey is the context key for the active proxy handler
// so that handle_response routes can inherit some config options
// from the proxy handler.
const proxyHandleResponseContextCtxKey caddy.CtxKey = "reverse_proxy_handle_response_context"
+var noUpstreamsAvailable = fmt.Errorf("no upstreams available")
+
// Interface guards
var (
_ caddy.Provisioner = (*Handler)(nil)
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go
index 0b7f50c..acb069a 100644
--- a/modules/caddyhttp/reverseproxy/selectionpolicies.go
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go
@@ -18,17 +18,20 @@ import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
+ "encoding/json"
"fmt"
"hash/fnv"
weakrand "math/rand"
"net"
"net/http"
"strconv"
+ "strings"
"sync/atomic"
- "time"
"github.com/caddyserver/caddy/v2"
+ "github.com/caddyserver/caddy/v2/caddyconfig"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
func init() {
@@ -36,13 +39,14 @@ func init() {
caddy.RegisterModule(RandomChoiceSelection{})
caddy.RegisterModule(LeastConnSelection{})
caddy.RegisterModule(RoundRobinSelection{})
+ caddy.RegisterModule(WeightedRoundRobinSelection{})
caddy.RegisterModule(FirstSelection{})
caddy.RegisterModule(IPHashSelection{})
+ caddy.RegisterModule(ClientIPHashSelection{})
caddy.RegisterModule(URIHashSelection{})
+ caddy.RegisterModule(QueryHashSelection{})
caddy.RegisterModule(HeaderHashSelection{})
caddy.RegisterModule(CookieHashSelection{})
-
- weakrand.Seed(time.Now().UTC().UnixNano())
}
// RandomSelection is a policy that selects
@@ -72,6 +76,90 @@ func (r *RandomSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return nil
}
+// WeightedRoundRobinSelection is a policy that selects
+// a host based on weighted round-robin ordering.
+type WeightedRoundRobinSelection struct {
+ // The weight of each upstream in order,
+ // corresponding with the list of upstreams configured.
+ Weights []int `json:"weights,omitempty"`
+ index uint32
+ totalWeight int
+}
+
+// CaddyModule returns the Caddy module information.
+func (WeightedRoundRobinSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ ID: "http.reverse_proxy.selection_policies.weighted_round_robin",
+ New: func() caddy.Module {
+ return new(WeightedRoundRobinSelection)
+ },
+ }
+}
+
+// UnmarshalCaddyfile sets up the module from Caddyfile tokens.
+func (r *WeightedRoundRobinSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
+ for d.Next() {
+ args := d.RemainingArgs()
+ if len(args) == 0 {
+ return d.ArgErr()
+ }
+
+ for _, weight := range args {
+ weightInt, err := strconv.Atoi(weight)
+ if err != nil {
+ return d.Errf("invalid weight value '%s': %v", weight, err)
+ }
+ if weightInt < 1 {
+				return d.Errf("invalid weight value '%s': weight must be a positive integer", weight)
+ }
+ r.Weights = append(r.Weights, weightInt)
+ }
+ }
+ return nil
+}
+
+// Provision sets up r.
+func (r *WeightedRoundRobinSelection) Provision(ctx caddy.Context) error {
+ for _, weight := range r.Weights {
+ r.totalWeight += weight
+ }
+ return nil
+}
+
+// Select returns an available host, if any.
+func (r *WeightedRoundRobinSelection) Select(pool UpstreamPool, _ *http.Request, _ http.ResponseWriter) *Upstream {
+ if len(pool) == 0 {
+ return nil
+ }
+ if len(r.Weights) < 2 {
+ return pool[0]
+ }
+ var index, totalWeight int
+ currentWeight := int(atomic.AddUint32(&r.index, 1)) % r.totalWeight
+ for i, weight := range r.Weights {
+ totalWeight += weight
+ if currentWeight < totalWeight {
+ index = i
+ break
+ }
+ }
+
+ upstreams := make([]*Upstream, 0, len(r.Weights))
+ for _, upstream := range pool {
+ if !upstream.Available() {
+ continue
+ }
+ upstreams = append(upstreams, upstream)
+ if len(upstreams) == cap(upstreams) {
+ break
+ }
+ }
+ if len(upstreams) == 0 {
+ return nil
+ }
+ return upstreams[index%len(upstreams)]
+}
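// Editor's sketch (not part of the diff): the weighted pick above in
// isolation. An atomic counter taken modulo the total weight indexes into
// the cumulative weights; with weights {3, 2, 1} the selection cycle is
// host 0, 0, 1, 1, 2, 0, matching the test expectations further below:
package main

import "fmt"

func pick(weights []int, counter int) int {
	total := 0
	for _, w := range weights {
		total += w
	}
	current := counter % total
	cumulative := 0
	for i, w := range weights {
		cumulative += w
		if current < cumulative {
			return i
		}
	}
	return 0
}

func main() {
	for counter := 1; counter <= 6; counter++ { // the real counter increments before selecting
		fmt.Print(pick([]int{3, 2, 1}, counter), " ")
	}
	fmt.Println() // 0 0 1 1 2 0
}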
+
// RandomChoiceSelection is a policy that selects
// two or more available hosts at random, then
// chooses the one with the least load.
@@ -181,7 +269,7 @@ func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request, _ http.Resp
// sample: https://en.wikipedia.org/wiki/Reservoir_sampling
if numReqs == leastReqs {
count++
- if (weakrand.Int() % count) == 0 { //nolint:gosec
+ if count == 1 || (weakrand.Int()%count) == 0 { //nolint:gosec
bestHost = host
}
}
@@ -303,6 +391,39 @@ func (r *IPHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return nil
}
+// ClientIPHashSelection is a policy that selects a host
+// based on hashing the client IP of the request, as determined
+// by the HTTP app's trusted proxies settings.
+type ClientIPHashSelection struct{}
+
+// CaddyModule returns the Caddy module information.
+func (ClientIPHashSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ ID: "http.reverse_proxy.selection_policies.client_ip_hash",
+ New: func() caddy.Module { return new(ClientIPHashSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (ClientIPHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
+ address := caddyhttp.GetVar(req.Context(), caddyhttp.ClientIPVarKey).(string)
+ clientIP, _, err := net.SplitHostPort(address)
+ if err != nil {
+ clientIP = address // no port
+ }
+ return hostByHashing(pool, clientIP)
+}
+
+// UnmarshalCaddyfile sets up the module from Caddyfile tokens.
+func (r *ClientIPHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
+ for d.Next() {
+ if d.NextArg() {
+ return d.ArgErr()
+ }
+ }
+ return nil
+}
+
// URIHashSelection is a policy that selects a
// host by hashing the request URI.
type URIHashSelection struct{}
@@ -330,11 +451,95 @@ func (r *URIHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return nil
}
+// QueryHashSelection is a policy that selects
+// a host based on a given request query parameter.
+type QueryHashSelection struct {
+ // The query key whose value is to be hashed and used for upstream selection.
+ Key string `json:"key,omitempty"`
+
+ // The fallback policy to use if the query key is not present. Defaults to `random`.
+ FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"`
+ fallback Selector
+}
+
+// CaddyModule returns the Caddy module information.
+func (QueryHashSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ ID: "http.reverse_proxy.selection_policies.query",
+ New: func() caddy.Module { return new(QueryHashSelection) },
+ }
+}
+
+// Provision sets up the module.
+func (s *QueryHashSelection) Provision(ctx caddy.Context) error {
+ if s.Key == "" {
+ return fmt.Errorf("query key is required")
+ }
+ if s.FallbackRaw == nil {
+ s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil)
+ }
+ mod, err := ctx.LoadModule(s, "FallbackRaw")
+ if err != nil {
+ return fmt.Errorf("loading fallback selection policy: %s", err)
+ }
+ s.fallback = mod.(Selector)
+ return nil
+}
+
+// Select returns an available host, if any.
+func (s QueryHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
+ // Since the query may have multiple values for the same key,
+ // we'll join them to avoid a problem where the user can control
+ // the upstream that the request goes to by sending multiple values
+ // for the same key, when the upstream only considers the first value.
+ // Keep in mind that a client changing the order of the values may
+ // affect which upstream is selected, but this is a semantically
+ // different request, because the order of the values is significant.
+ vals := strings.Join(req.URL.Query()[s.Key], ",")
+ if vals == "" {
+ return s.fallback.Select(pool, req, nil)
+ }
+ return hostByHashing(pool, vals)
+}
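// Editor's sketch (not part of the diff): why Select above joins all
// values for the query key before hashing — hashing only the first value
// would let a client steer upstream selection by appending duplicates:
package main

import (
	"fmt"
	"net/url"
	"strings"
)

func main() {
	q1, _ := url.ParseQuery("user=alice")
	q2, _ := url.ParseQuery("user=alice&user=bob")
	fmt.Println(strings.Join(q1["user"], ",")) // "alice"
	fmt.Println(strings.Join(q2["user"], ",")) // "alice,bob" — a different hash input
}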
+
+// UnmarshalCaddyfile sets up the module from Caddyfile tokens.
+func (s *QueryHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
+ for d.Next() {
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ s.Key = d.Val()
+ }
+ for nesting := d.Nesting(); d.NextBlock(nesting); {
+ switch d.Val() {
+ case "fallback":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if s.FallbackRaw != nil {
+ return d.Err("fallback selection policy already specified")
+ }
+ mod, err := loadFallbackPolicy(d)
+ if err != nil {
+ return err
+ }
+ s.FallbackRaw = mod
+ default:
+ return d.Errf("unrecognized option '%s'", d.Val())
+ }
+ }
+ return nil
+}
+
// HeaderHashSelection is a policy that selects
// a host based on a given request header.
type HeaderHashSelection struct {
// The HTTP header field whose value is to be hashed and used for upstream selection.
Field string `json:"field,omitempty"`
+
+ // The fallback policy to use if the header is not present. Defaults to `random`.
+ FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"`
+ fallback Selector
}
// CaddyModule returns the Caddy module information.
@@ -345,12 +550,24 @@ func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo {
}
}
-// Select returns an available host, if any.
-func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
+// Provision sets up the module.
+func (s *HeaderHashSelection) Provision(ctx caddy.Context) error {
if s.Field == "" {
- return nil
+ return fmt.Errorf("header field is required")
+ }
+ if s.FallbackRaw == nil {
+ s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil)
}
+ mod, err := ctx.LoadModule(s, "FallbackRaw")
+ if err != nil {
+ return fmt.Errorf("loading fallback selection policy: %s", err)
+ }
+ s.fallback = mod.(Selector)
+ return nil
+}
+// Select returns an available host, if any.
+func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
// The Host header should be obtained from the req.Host field
// since net/http removes it from the header map.
if s.Field == "Host" && req.Host != "" {
@@ -359,7 +576,7 @@ func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http
val := req.Header.Get(s.Field)
if val == "" {
- return RandomSelection{}.Select(pool, req, nil)
+ return s.fallback.Select(pool, req, nil)
}
return hostByHashing(pool, val)
}
@@ -372,6 +589,24 @@ func (s *HeaderHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
s.Field = d.Val()
}
+ for nesting := d.Nesting(); d.NextBlock(nesting); {
+ switch d.Val() {
+ case "fallback":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if s.FallbackRaw != nil {
+ return d.Err("fallback selection policy already specified")
+ }
+ mod, err := loadFallbackPolicy(d)
+ if err != nil {
+ return err
+ }
+ s.FallbackRaw = mod
+ default:
+ return d.Errf("unrecognized option '%s'", d.Val())
+ }
+ }
return nil
}
@@ -382,6 +617,10 @@ type CookieHashSelection struct {
Name string `json:"name,omitempty"`
// Secret to hash (Hmac256) chosen upstream in cookie
Secret string `json:"secret,omitempty"`
+
+ // The fallback policy to use if the cookie is not present. Defaults to `random`.
+ FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"`
+ fallback Selector
}
// CaddyModule returns the Caddy module information.
@@ -392,15 +631,48 @@ func (CookieHashSelection) CaddyModule() caddy.ModuleInfo {
}
}
-// Select returns an available host, if any.
-func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream {
+// Provision sets up the module.
+func (s *CookieHashSelection) Provision(ctx caddy.Context) error {
if s.Name == "" {
s.Name = "lb"
}
+ if s.FallbackRaw == nil {
+ s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil)
+ }
+ mod, err := ctx.LoadModule(s, "FallbackRaw")
+ if err != nil {
+ return fmt.Errorf("loading fallback selection policy: %s", err)
+ }
+ s.fallback = mod.(Selector)
+ return nil
+}
+
+// Select returns an available host, if any.
+func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream {
+	// selects a new host using the fallback policy (typically random)
+	// and writes a sticky session cookie to the response.
+ selectNewHost := func() *Upstream {
+ upstream := s.fallback.Select(pool, req, w)
+ if upstream == nil {
+ return nil
+ }
+ sha, err := hashCookie(s.Secret, upstream.Dial)
+ if err != nil {
+ return upstream
+ }
+ http.SetCookie(w, &http.Cookie{
+ Name: s.Name,
+ Value: sha,
+ Path: "/",
+ Secure: false,
+ })
+ return upstream
+ }
+
cookie, err := req.Cookie(s.Name)
- // If there's no cookie, select new random host
+ // If there's no cookie, select a host using the fallback policy
if err != nil || cookie == nil {
- return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name)
+ return selectNewHost()
}
// If the cookie is present, loop over the available upstreams until we find a match
cookieValue := cookie.Value
@@ -413,13 +685,15 @@ func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http
return upstream
}
}
- // If there is no matching host, select new random host
- return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name)
+ // If there is no matching host, select a host using the fallback policy
+ return selectNewHost()
}
// UnmarshalCaddyfile sets up the module from Caddyfile tokens. Syntax:
//
-// lb_policy cookie [<name> [<secret>]]
+// lb_policy cookie [<name> [<secret>]] {
+// fallback <policy>
+// }
//
// By default name is `lb`
func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
@@ -434,22 +708,25 @@ func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
default:
return d.ArgErr()
}
- return nil
-}
-
-// Select a new Host randomly and add a sticky session cookie
-func selectNewHostWithCookieHashSelection(pool []*Upstream, w http.ResponseWriter, cookieSecret string, cookieName string) *Upstream {
- randomHost := selectRandomHost(pool)
-
- if randomHost != nil {
- // Hash (HMAC with some key for privacy) the upstream.Dial string as the cookie value
- sha, err := hashCookie(cookieSecret, randomHost.Dial)
- if err == nil {
- // write the cookie.
- http.SetCookie(w, &http.Cookie{Name: cookieName, Value: sha, Path: "/", Secure: false})
+ for nesting := d.Nesting(); d.NextBlock(nesting); {
+ switch d.Val() {
+ case "fallback":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if s.FallbackRaw != nil {
+ return d.Err("fallback selection policy already specified")
+ }
+ mod, err := loadFallbackPolicy(d)
+ if err != nil {
+ return err
+ }
+ s.FallbackRaw = mod
+ default:
+ return d.Errf("unrecognized option '%s'", d.Val())
}
}
- return randomHost
+ return nil
}
// hashCookie hashes (HMAC 256) some data with the secret
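// Editor's sketch (not part of the diff): the sticky-cookie value set in
// selectNewHost is an HMAC-SHA256 of the upstream's dial address, keyed by
// the configured secret — consistent with the crypto/hmac, crypto/sha256,
// and encoding/hex imports at the top of this file. A standalone version:
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

func hashCookie(secret, data string) (string, error) {
	h := hmac.New(sha256.New, []byte(secret))
	if _, err := h.Write([]byte(data)); err != nil {
		return "", err
	}
	return hex.EncodeToString(h.Sum(nil)), nil
}

func main() {
	v, _ := hashCookie("my-secret", "10.0.0.1:8080")
	fmt.Println(v) // stable per upstream, so the client sticks to it
}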
@@ -512,6 +789,9 @@ func leastRequests(upstreams []*Upstream) *Upstream {
if len(best) == 0 {
return nil
}
+ if len(best) == 1 {
+ return best[0]
+ }
return best[weakrand.Intn(len(best))] //nolint:gosec
}
@@ -544,20 +824,40 @@ func hash(s string) uint32 {
return h.Sum32()
}
+func loadFallbackPolicy(d *caddyfile.Dispenser) (json.RawMessage, error) {
+ name := d.Val()
+ modID := "http.reverse_proxy.selection_policies." + name
+ unm, err := caddyfile.UnmarshalModule(d, modID)
+ if err != nil {
+ return nil, err
+ }
+ sel, ok := unm.(Selector)
+ if !ok {
+ return nil, d.Errf("module %s (%T) is not a reverseproxy.Selector", modID, unm)
+ }
+ return caddyconfig.JSONModuleObject(sel, "policy", name, nil), nil
+}
+
// Interface guards
var (
_ Selector = (*RandomSelection)(nil)
_ Selector = (*RandomChoiceSelection)(nil)
_ Selector = (*LeastConnSelection)(nil)
_ Selector = (*RoundRobinSelection)(nil)
+ _ Selector = (*WeightedRoundRobinSelection)(nil)
_ Selector = (*FirstSelection)(nil)
_ Selector = (*IPHashSelection)(nil)
+ _ Selector = (*ClientIPHashSelection)(nil)
_ Selector = (*URIHashSelection)(nil)
+ _ Selector = (*QueryHashSelection)(nil)
_ Selector = (*HeaderHashSelection)(nil)
_ Selector = (*CookieHashSelection)(nil)
- _ caddy.Validator = (*RandomChoiceSelection)(nil)
+ _ caddy.Validator = (*RandomChoiceSelection)(nil)
+
_ caddy.Provisioner = (*RandomChoiceSelection)(nil)
+ _ caddy.Provisioner = (*WeightedRoundRobinSelection)(nil)
_ caddyfile.Unmarshaler = (*RandomChoiceSelection)(nil)
+ _ caddyfile.Unmarshaler = (*WeightedRoundRobinSelection)(nil)
)
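The interface guards above also spell out the contract a policy must satisfy to be usable with the new fallback subdirective: a registered selection-policy module implementing Selector, plus caddyfile.Unmarshaler so loadFallbackPolicy can hand it the dispenser via caddyfile.UnmarshalModule. A hypothetical sketch, as it would sit inside this package (the name echo_first and its behavior are invented for illustration, not part of this change):

    func init() {
        caddy.RegisterModule(EchoFirstSelection{})
    }

    // EchoFirstSelection is a hypothetical policy shown only to
    // illustrate the Selector contract; it mimics FirstSelection.
    type EchoFirstSelection struct{}

    // CaddyModule returns the Caddy module information.
    func (EchoFirstSelection) CaddyModule() caddy.ModuleInfo {
        return caddy.ModuleInfo{
            ID:  "http.reverse_proxy.selection_policies.echo_first",
            New: func() caddy.Module { return new(EchoFirstSelection) },
        }
    }

    // Select returns the first available upstream, or nil.
    func (EchoFirstSelection) Select(pool UpstreamPool, _ *http.Request, _ http.ResponseWriter) *Upstream {
        for _, upstream := range pool {
            if upstream.Available() {
                return upstream
            }
        }
        return nil
    }

    // UnmarshalCaddyfile rejects arguments, since this policy takes
    // none; implementing it is what makes the policy loadable from
    // the fallback subdirective's Caddyfile tokens.
    func (*EchoFirstSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
        if d.NextArg() {
            return d.ArgErr()
        }
        return nil
    }

    // Interface guards, in the same style as above.
    var (
        _ Selector              = (*EchoFirstSelection)(nil)
        _ caddyfile.Unmarshaler = (*EchoFirstSelection)(nil)
    )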
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
index 546a60d..dc613a5 100644
--- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
@@ -15,9 +15,14 @@
package reverseproxy
import (
+ "context"
"net/http"
"net/http/httptest"
"testing"
+
+ "github.com/caddyserver/caddy/v2"
+ "github.com/caddyserver/caddy/v2/caddyconfig"
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
func testPool() UpstreamPool {
@@ -30,7 +35,7 @@ func testPool() UpstreamPool {
func TestRoundRobinPolicy(t *testing.T) {
pool := testPool()
- rrPolicy := new(RoundRobinSelection)
+ rrPolicy := RoundRobinSelection{}
req, _ := http.NewRequest("GET", "/", nil)
h := rrPolicy.Select(pool, req, nil)
@@ -69,9 +74,66 @@ func TestRoundRobinPolicy(t *testing.T) {
}
}
+func TestWeightedRoundRobinPolicy(t *testing.T) {
+ pool := testPool()
+ wrrPolicy := WeightedRoundRobinSelection{
+ Weights: []int{3, 2, 1},
+ totalWeight: 6,
+ }
+ req, _ := http.NewRequest("GET", "/", nil)
+
+ h := wrrPolicy.Select(pool, req, nil)
+ if h != pool[0] {
+ t.Error("Expected first weighted round robin host to be first host in the pool.")
+ }
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[0] {
+ t.Error("Expected second weighted round robin host to be first host in the pool.")
+ }
+ // The third selection is pool[1] rather than pool[0], because the
+ // counter starts at 0 and is incremented before the host is selected
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected third weighted round robin host to be second host in the pool.")
+ }
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected fourth weighted round robin host to be second host in the pool.")
+ }
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[2] {
+ t.Error("Expected fifth weighted round robin host to be third host in the pool.")
+ }
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[0] {
+ t.Error("Expected sixth weighted round robin host to be first host in the pool.")
+ }
+
+ // mark host as down
+ pool[0].setHealthy(false)
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected to skip down host.")
+ }
+ // mark host as up
+ pool[0].setHealthy(true)
+
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[0] {
+ t.Error("Expected to select first host on availablity.")
+ }
+ // mark host as full
+ pool[1].countRequest(1)
+ pool[1].MaxRequests = 1
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[2] {
+ t.Error("Expected to skip full host.")
+ }
+}
+
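Assuming the Caddyfile adapter passes the weights through as arguments (my reading of the new policy's syntax, which is not shown in this hunk), the 3/2/1 pattern exercised above would come from a config like:

    reverse_proxy node1:8080 node2:8080 node3:8080 {
        lb_policy weighted_round_robin 3 2 1
    }

Each weight applies to the upstream at the same position, so per six-request cycle node1 receives three consecutive picks, node2 two, and node3 one, exactly as the assertions trace out.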
func TestLeastConnPolicy(t *testing.T) {
pool := testPool()
- lcPolicy := new(LeastConnSelection)
+ lcPolicy := LeastConnSelection{}
req, _ := http.NewRequest("GET", "/", nil)
pool[0].countRequest(10)
@@ -89,7 +151,7 @@ func TestLeastConnPolicy(t *testing.T) {
func TestIPHashPolicy(t *testing.T) {
pool := testPool()
- ipHash := new(IPHashSelection)
+ ipHash := IPHashSelection{}
req, _ := http.NewRequest("GET", "/", nil)
// We should be able to predict where every request is routed.
@@ -229,9 +291,152 @@ func TestIPHashPolicy(t *testing.T) {
}
}
+func TestClientIPHashPolicy(t *testing.T) {
+ pool := testPool()
+ ipHash := ClientIPHashSelection{}
+ req, _ := http.NewRequest("GET", "/", nil)
+ req = req.WithContext(context.WithValue(req.Context(), caddyhttp.VarsCtxKey, make(map[string]any)))
+
+ // We should be able to predict where every request is routed.
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.1:80")
+ h := ipHash.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2:80")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3:80")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4:80")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+
+ // we should get the same results without a port
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.1")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+
+ // we should get a healthy host if the original host is unhealthy and a
+ // healthy host is available
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4")
+ pool[1].setHealthy(false)
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[0] {
+ t.Error("Expected ip hash policy host to be the first host.")
+ }
+
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[0] {
+ t.Error("Expected ip hash policy host to be the first host.")
+ }
+ pool[1].setHealthy(true)
+
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3")
+ pool[2].setHealthy(false)
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+
+ // We should be able to resize the host pool and still be able to predict
+ // where a req will be routed with the same IPs used above
+ pool = UpstreamPool{
+ {Host: new(Host), Dial: "0.0.0.2"},
+ {Host: new(Host), Dial: "0.0.0.3"},
+ }
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.1:80")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[0] {
+ t.Error("Expected ip hash policy host to be the first host.")
+ }
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2:80")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[0] {
+ t.Error("Expected ip hash policy host to be the first host.")
+ }
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3:80")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[0] {
+ t.Error("Expected ip hash policy host to be the first host.")
+ }
+ caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4:80")
+ h = ipHash.Select(pool, req, nil)
+ if h != pool[0] {
+ t.Error("Expected ip hash policy host to be the first host.")
+ }
+
+ // We should get nil when there are no healthy hosts
+ pool[0].setHealthy(false)
+ pool[1].setHealthy(false)
+ h = ipHash.Select(pool, req, nil)
+ if h != nil {
+ t.Error("Expected ip hash policy host to be nil.")
+ }
+
+ // Reproduce #4135
+ pool = UpstreamPool{
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ }
+ pool[0].setHealthy(false)
+ pool[1].setHealthy(false)
+ pool[2].setHealthy(false)
+ pool[3].setHealthy(false)
+ pool[4].setHealthy(false)
+ pool[5].setHealthy(false)
+ pool[6].setHealthy(false)
+ pool[7].setHealthy(false)
+ pool[8].setHealthy(true)
+
+ // We should get a result back when there is one healthy host left.
+ h = ipHash.Select(pool, req, nil)
+ if h == nil {
+ // If it is nil, it means we missed a host even though one is available
+ t.Error("Expected ip hash policy host to not be nil, but it is nil.")
+ }
+}
+
func TestFirstPolicy(t *testing.T) {
pool := testPool()
- firstPolicy := new(FirstSelection)
+ firstPolicy := FirstSelection{}
req := httptest.NewRequest(http.MethodGet, "/", nil)
h := firstPolicy.Select(pool, req, nil)
@@ -246,9 +451,85 @@ func TestFirstPolicy(t *testing.T) {
}
}
+func TestQueryHashPolicy(t *testing.T) {
+ ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
+ defer cancel()
+ queryPolicy := QueryHashSelection{Key: "foo"}
+ if err := queryPolicy.Provision(ctx); err != nil {
+ t.Errorf("Provision error: %v", err)
+ t.FailNow()
+ }
+
+ pool := testPool()
+
+ request := httptest.NewRequest(http.MethodGet, "/?foo=1", nil)
+ h := queryPolicy.Select(pool, request, nil)
+ if h != pool[0] {
+ t.Error("Expected query policy host to be the first host.")
+ }
+
+ request = httptest.NewRequest(http.MethodGet, "/?foo=100000", nil)
+ h = queryPolicy.Select(pool, request, nil)
+ if h != pool[0] {
+ t.Error("Expected query policy host to be the first host.")
+ }
+
+ request = httptest.NewRequest(http.MethodGet, "/?foo=1", nil)
+ pool[0].setHealthy(false)
+ h = queryPolicy.Select(pool, request, nil)
+ if h != pool[1] {
+ t.Error("Expected query policy host to be the second host.")
+ }
+
+ request = httptest.NewRequest(http.MethodGet, "/?foo=100000", nil)
+ h = queryPolicy.Select(pool, request, nil)
+ if h != pool[2] {
+ t.Error("Expected query policy host to be the third host.")
+ }
+
+ // We should be able to resize the host pool and still be able to predict
+ // where a request will be routed with the same query used above
+ pool = UpstreamPool{
+ {Host: new(Host)},
+ {Host: new(Host)},
+ }
+
+ request = httptest.NewRequest(http.MethodGet, "/?foo=1", nil)
+ h = queryPolicy.Select(pool, request, nil)
+ if h != pool[0] {
+ t.Error("Expected query policy host to be the first host.")
+ }
+
+ pool[0].setHealthy(false)
+ h = queryPolicy.Select(pool, request, nil)
+ if h != pool[1] {
+ t.Error("Expected query policy host to be the second host.")
+ }
+
+ request = httptest.NewRequest(http.MethodGet, "/?foo=4", nil)
+ h = queryPolicy.Select(pool, request, nil)
+ if h != pool[1] {
+ t.Error("Expected query policy host to be the second host.")
+ }
+
+ pool[0].setHealthy(false)
+ pool[1].setHealthy(false)
+ h = queryPolicy.Select(pool, request, nil)
+ if h != nil {
+ t.Error("Expected query policy policy host to be nil.")
+ }
+
+ request = httptest.NewRequest(http.MethodGet, "/?foo=aa11&foo=bb22", nil)
+ pool = testPool()
+ h = queryPolicy.Select(pool, request, nil)
+ if h != pool[0] {
+ t.Error("Expected query policy host to be the first host.")
+ }
+}
+
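The query policy under test hashes a request's query value; a hedged Caddyfile sketch, assuming the policy name and its key argument follow the same convention as the other hash policies in this file:

    reverse_proxy node1:8080 node2:8080 {
        lb_policy query foo
    }

Requests carrying the same ?foo=... value are routed to the same upstream while it stays healthy, as the assertions above demonstrate.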
func TestURIHashPolicy(t *testing.T) {
pool := testPool()
- uriPolicy := new(URIHashSelection)
+ uriPolicy := URIHashSelection{}
request := httptest.NewRequest(http.MethodGet, "/test", nil)
h := uriPolicy.Select(pool, request, nil)
@@ -337,8 +618,7 @@ func TestRandomChoicePolicy(t *testing.T) {
pool[2].countRequest(30)
request := httptest.NewRequest(http.MethodGet, "/test", nil)
- randomChoicePolicy := new(RandomChoiceSelection)
- randomChoicePolicy.Choose = 2
+ randomChoicePolicy := RandomChoiceSelection{Choose: 2}
h := randomChoicePolicy.Select(pool, request, nil)
@@ -353,6 +633,14 @@ func TestRandomChoicePolicy(t *testing.T) {
}
func TestCookieHashPolicy(t *testing.T) {
+ ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
+ defer cancel()
+ cookieHashPolicy := CookieHashSelection{}
+ if err := cookieHashPolicy.Provision(ctx); err != nil {
+ t.Errorf("Provision error: %v", err)
+ t.FailNow()
+ }
+
pool := testPool()
pool[0].Dial = "localhost:8080"
pool[1].Dial = "localhost:8081"
@@ -362,7 +650,7 @@ func TestCookieHashPolicy(t *testing.T) {
pool[2].setHealthy(false)
request := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
- cookieHashPolicy := new(CookieHashSelection)
+
h := cookieHashPolicy.Select(pool, request, w)
cookieServer1 := w.Result().Cookies()[0]
if cookieServer1 == nil {
@@ -399,3 +687,59 @@ func TestCookieHashPolicy(t *testing.T) {
t.Error("Expected cookieHashPolicy to set a new cookie.")
}
}
+
+func TestCookieHashPolicyWithFirstFallback(t *testing.T) {
+ ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
+ defer cancel()
+ cookieHashPolicy := CookieHashSelection{
+ FallbackRaw: caddyconfig.JSONModuleObject(FirstSelection{}, "policy", "first", nil),
+ }
+ if err := cookieHashPolicy.Provision(ctx); err != nil {
+ t.Errorf("Provision error: %v", err)
+ t.FailNow()
+ }
+
+ pool := testPool()
+ pool[0].Dial = "localhost:8080"
+ pool[1].Dial = "localhost:8081"
+ pool[2].Dial = "localhost:8082"
+ pool[0].setHealthy(true)
+ pool[1].setHealthy(true)
+ pool[2].setHealthy(true)
+ request := httptest.NewRequest(http.MethodGet, "/test", nil)
+ w := httptest.NewRecorder()
+
+ h := cookieHashPolicy.Select(pool, request, w)
+ cookieServer1 := w.Result().Cookies()[0]
+ if cookieServer1 == nil {
+ t.Fatal("cookieHashPolicy should set a cookie")
+ }
+ if cookieServer1.Name != "lb" {
+ t.Error("cookieHashPolicy should set a cookie with name lb")
+ }
+ if h != pool[0] {
+ t.Errorf("Expected cookieHashPolicy host to be the first only available host, got %s", h)
+ }
+ request = httptest.NewRequest(http.MethodGet, "/test", nil)
+ w = httptest.NewRecorder()
+ request.AddCookie(cookieServer1)
+ h = cookieHashPolicy.Select(pool, request, w)
+ if h != pool[0] {
+ t.Errorf("Expected cookieHashPolicy host to stick to the first host (matching cookie), got %s", h)
+ }
+ s := w.Result().Cookies()
+ if len(s) != 0 {
+ t.Error("Expected cookieHashPolicy to not set a new cookie.")
+ }
+ pool[0].setHealthy(false)
+ request = httptest.NewRequest(http.MethodGet, "/test", nil)
+ w = httptest.NewRecorder()
+ request.AddCookie(cookieServer1)
+ h = cookieHashPolicy.Select(pool, request, w)
+ if h != pool[1] {
+ t.Errorf("Expected cookieHashPolicy to select the next first available host, got %s", h)
+ }
+ if len(w.Result().Cookies()) == 0 {
+ t.Error("Expected cookieHashPolicy to set a new cookie.")
+ }
+}
diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go
index 1db107a..155a1df 100644
--- a/modules/caddyhttp/reverseproxy/streaming.go
+++ b/modules/caddyhttp/reverseproxy/streaming.go
@@ -20,6 +20,8 @@ package reverseproxy
import (
"context"
+ "errors"
+ "fmt"
"io"
weakrand "math/rand"
"mime"
@@ -32,32 +34,46 @@ import (
"golang.org/x/net/http/httpguts"
)
-func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) {
+func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
// Taken from https://github.com/golang/go/commit/5c489514bc5e61ad9b5b07bd7d8ec65d66a0512a
// We know reqUpType is ASCII, it's checked by the caller.
if !asciiIsPrint(resUpType) {
- h.logger.Debug("backend tried to switch to invalid protocol",
+ logger.Debug("backend tried to switch to invalid protocol",
zap.String("backend_upgrade", resUpType))
return
}
if !asciiEqualFold(reqUpType, resUpType) {
- h.logger.Debug("backend tried to switch to unexpected protocol via Upgrade header",
+ logger.Debug("backend tried to switch to unexpected protocol via Upgrade header",
zap.String("backend_upgrade", resUpType),
zap.String("requested_upgrade", reqUpType))
return
}
- hj, ok := rw.(http.Hijacker)
+ backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
+ logger.Error("internal error: 101 switching protocols response with non-writable body")
+ return
+ }
+
+ // write header first, response headers should not be counted in size
+ // like the rest of handler chain.
+ copyHeader(rw.Header(), res.Header)
+ rw.WriteHeader(res.StatusCode)
+
+ logger.Debug("upgrading connection")
+
+ //nolint:bodyclose
+ conn, brw, hijackErr := http.NewResponseController(rw).Hijack()
+ if errors.Is(hijackErr, http.ErrNotSupported) {
h.logger.Sugar().Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)
return
}
- backConn, ok := res.Body.(io.ReadWriteCloser)
- if !ok {
- h.logger.Error("internal error: 101 switching protocols response with non-writable body")
+
+ if hijackErr != nil {
+ h.logger.Error("hijack failed on protocol switch", zap.Error(hijackErr))
return
}
@@ -74,18 +90,6 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
}()
defer close(backConnCloseCh)
- // write header first, response headers should not be counted in size
- // like the rest of handler chain.
- copyHeader(rw.Header(), res.Header)
- rw.WriteHeader(res.StatusCode)
-
- logger.Debug("upgrading connection")
- conn, brw, err := hj.Hijack()
- if err != nil {
- h.logger.Error("hijack failed on protocol switch", zap.Error(err))
- return
- }
-
start := time.Now()
defer func() {
conn.Close()
@@ -93,7 +97,7 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
}()
if err := brw.Flush(); err != nil {
- h.logger.Debug("response flush", zap.Error(err))
+ logger.Debug("response flush", zap.Error(err))
return
}
@@ -119,10 +123,23 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
spc := switchProtocolCopier{user: conn, backend: backConn}
+ // set up the timeout if requested
+ var timeoutc <-chan time.Time
+ if h.StreamTimeout > 0 {
+ timer := time.NewTimer(time.Duration(h.StreamTimeout))
+ defer timer.Stop()
+ timeoutc = timer.C
+ }
+
errc := make(chan error, 1)
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
- <-errc
+ select {
+ case err := <-errc:
+ logger.Debug("streaming error", zap.Error(err))
+ case time := <-timeoutc:
+ logger.Debug("stream timed out", zap.Time("timeout", time))
+ }
}
// flushInterval returns the p.FlushInterval value, conditionally
@@ -167,38 +184,58 @@ func (h Handler) isBidirectionalStream(req *http.Request, res *http.Response) bo
(ae == "identity" || ae == "")
}
-func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
+func (h Handler) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration, logger *zap.Logger) error {
+ var w io.Writer = dst
+
if flushInterval != 0 {
- if wf, ok := dst.(writeFlusher); ok {
- mlw := &maxLatencyWriter{
- dst: wf,
- latency: flushInterval,
- }
- defer mlw.stop()
+ var mlwLogger *zap.Logger
+ if h.VerboseLogs {
+ mlwLogger = logger.Named("max_latency_writer")
+ } else {
+ mlwLogger = zap.NewNop()
+ }
+ mlw := &maxLatencyWriter{
+ dst: dst,
+ //nolint:bodyclose
+ flush: http.NewResponseController(dst).Flush,
+ latency: flushInterval,
+ logger: mlwLogger,
+ }
+ defer mlw.stop()
- // set up initial timer so headers get flushed even if body writes are delayed
- mlw.flushPending = true
- mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
+ // set up initial timer so headers get flushed even if body writes are delayed
+ mlw.flushPending = true
+ mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
- dst = mlw
- }
+ w = mlw
}
buf := streamingBufPool.Get().(*[]byte)
defer streamingBufPool.Put(buf)
- _, err := h.copyBuffer(dst, src, *buf)
+
+ var copyLogger *zap.Logger
+ if h.VerboseLogs {
+ copyLogger = logger
+ } else {
+ copyLogger = zap.NewNop()
+ }
+
+ _, err := h.copyBuffer(w, src, *buf, copyLogger)
return err
}
// copyBuffer returns any write errors or non-EOF read errors, and the amount
// of bytes written.
-func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
+func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *zap.Logger) (int64, error) {
if len(buf) == 0 {
buf = make([]byte, defaultBufferSize)
}
var written int64
for {
+ logger.Debug("waiting to read from upstream")
nr, rerr := src.Read(buf)
+ logger := logger.With(zap.Int("read", nr))
+ logger.Debug("read from upstream", zap.Error(rerr))
if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
// TODO: this could be useful to know (indeed, it revealed an error in our
// fastcgi PoC earlier; but it's this single error report here that necessitates
@@ -210,12 +247,17 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, er
h.logger.Error("reading from backend", zap.Error(rerr))
}
if nr > 0 {
+ logger.Debug("writing to downstream")
nw, werr := dst.Write(buf[:nr])
if nw > 0 {
written += int64(nw)
}
+ logger.Debug("wrote to downstream",
+ zap.Int("written", nw),
+ zap.Int64("written_total", written),
+ zap.Error(werr))
if werr != nil {
- return written, werr
+ return written, fmt.Errorf("writing: %w", werr)
}
if nr != nw {
return written, io.ErrShortWrite
@@ -223,9 +265,9 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, er
}
if rerr != nil {
if rerr == io.EOF {
- rerr = nil
+ return written, nil
}
- return written, rerr
+ return written, fmt.Errorf("reading: %w", rerr)
}
}
}
@@ -242,10 +284,70 @@ func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func
return func() {
h.connectionsMu.Lock()
delete(h.connections, conn)
+ // if there are no connections left before the connections close timer fires
+ if len(h.connections) == 0 && h.connectionsCloseTimer != nil {
+ // we release the timer that holds the reference to Handler
+ if (*h.connectionsCloseTimer).Stop() {
+ h.logger.Debug("stopped streaming connections close timer - all connections are already closed")
+ }
+ h.connectionsCloseTimer = nil
+ }
h.connectionsMu.Unlock()
}
}
+// closeConnections immediately closes all hijacked connections (both to client and backend).
+func (h *Handler) closeConnections() error {
+ var err error
+ h.connectionsMu.Lock()
+ defer h.connectionsMu.Unlock()
+
+ for _, oc := range h.connections {
+ if oc.gracefulClose != nil {
+ // this is potentially blocking while we have the lock on the connections
+ // map, but that should be OK since the server has in theory shut down
+ // and we are no longer using the connections map
+ gracefulErr := oc.gracefulClose()
+ if gracefulErr != nil && err == nil {
+ err = gracefulErr
+ }
+ }
+ closeErr := oc.conn.Close()
+ if closeErr != nil && err == nil {
+ err = closeErr
+ }
+ }
+ return err
+}
+
+// cleanupConnections closes hijacked connections.
+// Depending on the value of StreamCloseDelay it does that either immediately
+// or sets up a timer that will do that later.
+func (h *Handler) cleanupConnections() error {
+ if h.StreamCloseDelay == 0 {
+ return h.closeConnections()
+ }
+
+ h.connectionsMu.Lock()
+ defer h.connectionsMu.Unlock()
+ // the handler is shut down, no new connection can appear,
+ // so we can skip setting up the timer when there are no connections
+ if len(h.connections) > 0 {
+ delay := time.Duration(h.StreamCloseDelay)
+ h.connectionsCloseTimer = time.AfterFunc(delay, func() {
+ h.logger.Debug("closing streaming connections after delay",
+ zap.Duration("delay", delay))
+ err := h.closeConnections()
+ if err != nil {
+ h.logger.Error("failed to closed connections after delay",
+ zap.Error(err),
+ zap.Duration("delay", delay))
+ }
+ })
+ }
+ return nil
+}
+
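Both knobs used in this file, StreamTimeout (the select on timeoutc above) and StreamCloseDelay (the delayed cleanup above), are duration fields on the handler. A hedged JSON sketch of setting them, assuming the JSON tags follow Caddy's usual snake_case convention (values are illustrative):

    {
        "handler": "reverse_proxy",
        "upstreams": [{ "dial": "localhost:8080" }],
        "stream_timeout": "1m",
        "stream_close_delay": "5s"
    }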
// writeCloseControl sends a best-effort Close control message to the given
// WebSocket connection. Thanks to @pascaldekloe who provided inspiration
// from his simple implementation of this I was able to learn from at:
@@ -365,29 +467,30 @@ type openConnection struct {
gracefulClose func() error
}
-type writeFlusher interface {
- io.Writer
- http.Flusher
-}
-
type maxLatencyWriter struct {
- dst writeFlusher
+ dst io.Writer
+ flush func() error
latency time.Duration // non-zero; negative means to flush immediately
mu sync.Mutex // protects t, flushPending, and dst.Flush
t *time.Timer
flushPending bool
+ logger *zap.Logger
}
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
n, err = m.dst.Write(p)
+ m.logger.Debug("wrote bytes", zap.Int("n", n), zap.Error(err))
if m.latency < 0 {
- m.dst.Flush()
+ m.logger.Debug("flushing immediately")
+ //nolint:errcheck
+ m.flush()
return
}
if m.flushPending {
+ m.logger.Debug("delayed flush already pending")
return
}
if m.t == nil {
@@ -395,6 +498,7 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
} else {
m.t.Reset(m.latency)
}
+ m.logger.Debug("timer set for delayed flush", zap.Duration("duration", m.latency))
m.flushPending = true
return
}
@@ -403,9 +507,12 @@ func (m *maxLatencyWriter) delayedFlush() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
+ m.logger.Debug("delayed flush is not pending")
return
}
- m.dst.Flush()
+ m.logger.Debug("delayed flush")
+ //nolint:errcheck
+ m.flush()
m.flushPending = false
}
@@ -445,5 +552,7 @@ var streamingBufPool = sync.Pool{
},
}
-const defaultBufferSize = 32 * 1024
-const wordSize = int(unsafe.Sizeof(uintptr(0)))
+const (
+ defaultBufferSize = 32 * 1024
+ wordSize = int(unsafe.Sizeof(uintptr(0)))
+)
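A notable pattern in this file's changes is the move away from direct http.Hijacker / http.Flusher type assertions to http.NewResponseController (Go 1.20), which also works through wrapped response writers that implement Unwrap. A standalone sketch of the pattern, independent of this module:

    package main

    import (
        "errors"
        "fmt"
        "net/http"
    )

    func handler(w http.ResponseWriter, r *http.Request) {
        rc := http.NewResponseController(w)

        fmt.Fprintln(w, "first chunk")

        // Flush returns http.ErrNotSupported when nothing in the
        // ResponseWriter chain can flush, mirroring the check that
        // handleUpgradeResponse performs for Hijack.
        if err := rc.Flush(); err != nil {
            if errors.Is(err, http.ErrNotSupported) {
                // fall back to buffered behavior
                return
            }
            // any other error means the connection is in a bad state
            return
        }

        fmt.Fprintln(w, "second chunk")
    }

    func main() {
        http.HandleFunc("/", handler)
        _ = http.ListenAndServe("localhost:8080", nil)
    }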
diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go
index 4ed1f1e..3f6da2f 100644
--- a/modules/caddyhttp/reverseproxy/streaming_test.go
+++ b/modules/caddyhttp/reverseproxy/streaming_test.go
@@ -2,8 +2,11 @@ package reverseproxy
import (
"bytes"
+ "net/http/httptest"
"strings"
"testing"
+
+ "github.com/caddyserver/caddy/v2"
)
func TestHandlerCopyResponse(t *testing.T) {
@@ -13,12 +16,15 @@ func TestHandlerCopyResponse(t *testing.T) {
strings.Repeat("a", defaultBufferSize),
strings.Repeat("123456789 123456789 123456789 12", 3000),
}
+
dst := bytes.NewBuffer(nil)
+ recorder := httptest.NewRecorder()
+ recorder.Body = dst
for _, d := range testdata {
src := bytes.NewBuffer([]byte(d))
dst.Reset()
- err := h.copyResponse(dst, src, 0)
+ err := h.copyResponse(recorder, src, 0, caddy.Log())
if err != nil {
t.Errorf("failed with error: %v", err)
}
diff --git a/modules/caddyhttp/reverseproxy/upstreams.go b/modules/caddyhttp/reverseproxy/upstreams.go
index 7a90016..2d21a5c 100644
--- a/modules/caddyhttp/reverseproxy/upstreams.go
+++ b/modules/caddyhttp/reverseproxy/upstreams.go
@@ -8,12 +8,12 @@ import (
"net"
"net/http"
"strconv"
- "strings"
"sync"
"time"
- "github.com/caddyserver/caddy/v2"
"go.uber.org/zap"
+
+ "github.com/caddyserver/caddy/v2"
)
func init() {
@@ -114,7 +114,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
cached := srvs[suAddr]
srvsMu.RUnlock()
if cached.isFresh() {
- return cached.upstreams, nil
+ return allNew(cached.upstreams), nil
}
// otherwise, obtain a write-lock to update the cached value
@@ -126,7 +126,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// have refreshed it in the meantime before we re-obtained our lock
cached = srvs[suAddr]
if cached.isFresh() {
- return cached.upstreams, nil
+ return allNew(cached.upstreams), nil
}
su.logger.Debug("refreshing SRV upstreams",
@@ -145,7 +145,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
su.logger.Warn("SRV records filtered", zap.Error(err))
}
- upstreams := make([]*Upstream, len(records))
+ upstreams := make([]Upstream, len(records))
for i, rec := range records {
su.logger.Debug("discovered SRV record",
zap.String("target", rec.Target),
@@ -153,7 +153,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
zap.Uint16("priority", rec.Priority),
zap.Uint16("weight", rec.Weight))
addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port)))
- upstreams[i] = &Upstream{Dial: addr}
+ upstreams[i] = Upstream{Dial: addr}
}
// before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
@@ -170,7 +170,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
upstreams: upstreams,
}
- return upstreams, nil
+ return allNew(upstreams), nil
}
func (su SRVUpstreams) String() string {
@@ -206,13 +206,18 @@ func (SRVUpstreams) formattedAddr(service, proto, name string) string {
type srvLookup struct {
srvUpstreams SRVUpstreams
freshness time.Time
- upstreams []*Upstream
+ upstreams []Upstream
}
func (sl srvLookup) isFresh() bool {
return time.Since(sl.freshness) < time.Duration(sl.srvUpstreams.Refresh)
}
+type IPVersions struct {
+ IPv4 *bool `json:"ipv4,omitempty"`
+ IPv6 *bool `json:"ipv6,omitempty"`
+}
+
// AUpstreams provides upstreams from A/AAAA lookups.
// Results are cached and refreshed at the configured
// refresh interval.
@@ -240,7 +245,14 @@ type AUpstreams struct {
// A negative value disables this.
FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
+ // The IP versions to resolve for. By default, both
+ // "ipv4" and "ipv6" will be enabled, which
+ // correspond to A and AAAA records respectively.
+ Versions *IPVersions `json:"versions,omitempty"`
+
resolver *net.Resolver
+
+ logger *zap.Logger
}
// CaddyModule returns the Caddy module information.
@@ -251,7 +263,8 @@ func (AUpstreams) CaddyModule() caddy.ModuleInfo {
}
}
-func (au *AUpstreams) Provision(_ caddy.Context) error {
+func (au *AUpstreams) Provision(ctx caddy.Context) error {
+ au.logger = ctx.Logger()
if au.Refresh == 0 {
au.Refresh = caddy.Duration(time.Minute)
}
@@ -286,14 +299,36 @@ func (au *AUpstreams) Provision(_ caddy.Context) error {
func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
- auStr := repl.ReplaceAll(au.String(), "")
+
+ resolveIpv4 := au.Versions == nil || au.Versions.IPv4 == nil || *au.Versions.IPv4
+ resolveIpv6 := au.Versions == nil || au.Versions.IPv6 == nil || *au.Versions.IPv6
+
+ // Map ipVersion early, so we can use it as part of the cache-key.
+ // This should be fairly inexpensive and comes with the upside of
+ // allowing the same dynamic upstream (name + port combination)
+ // to be used multiple times with different IP versions.
+ //
+ // It also forces a cache-miss if a previously cached dynamic
+ // upstream changes its ip version, e.g. after a config reload,
+ // while keeping the cache-invalidation as simple as it currently is.
+ var ipVersion string
+ switch {
+ case resolveIpv4 && !resolveIpv6:
+ ipVersion = "ip4"
+ case !resolveIpv4 && resolveIpv6:
+ ipVersion = "ip6"
+ default:
+ ipVersion = "ip"
+ }
+
+ auStr := repl.ReplaceAll(au.String()+ipVersion, "")
// first, use a cheap read-lock to return a cached result quickly
aAaaaMu.RLock()
cached := aAaaa[auStr]
aAaaaMu.RUnlock()
if cached.isFresh() {
- return cached.upstreams, nil
+ return allNew(cached.upstreams), nil
}
// otherwise, obtain a write-lock to update the cached value
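Putting the new Versions field together with the JSON tags declared earlier in this file, a dynamic upstream source restricted to A records might look like this (host and port are illustrative):

    {
        "dynamic_upstreams": {
            "source": "a",
            "name": "backend.example.com",
            "port": "8080",
            "versions": { "ipv6": false }
        }
    }

Leaving versions out (or setting neither field) keeps the default of resolving both A and AAAA records, i.e. the "ip" lookup in the switch above.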
@@ -305,26 +340,33 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// have refreshed it in the meantime before we re-obtained our lock
cached = aAaaa[auStr]
if cached.isFresh() {
- return cached.upstreams, nil
+ return allNew(cached.upstreams), nil
}
name := repl.ReplaceAll(au.Name, "")
port := repl.ReplaceAll(au.Port, "")
- ips, err := au.resolver.LookupIPAddr(r.Context(), name)
+ au.logger.Debug("refreshing A upstreams",
+ zap.String("version", ipVersion),
+ zap.String("name", name),
+ zap.String("port", port))
+
+ ips, err := au.resolver.LookupIP(r.Context(), ipVersion, name)
if err != nil {
return nil, err
}
- upstreams := make([]*Upstream, len(ips))
+ upstreams := make([]Upstream, len(ips))
for i, ip := range ips {
- upstreams[i] = &Upstream{
+ au.logger.Debug("discovered A record",
+ zap.String("ip", ip.String()))
+ upstreams[i] = Upstream{
Dial: net.JoinHostPort(ip.String(), port),
}
}
// before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
- if cached.freshness.IsZero() && len(srvs) >= 100 {
+ if cached.freshness.IsZero() && len(aAaaa) >= 100 {
for randomKey := range aAaaa {
delete(aAaaa, randomKey)
break
@@ -337,7 +379,7 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
upstreams: upstreams,
}
- return upstreams, nil
+ return allNew(upstreams), nil
}
func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port) }
@@ -345,7 +387,7 @@ func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port)
type aLookup struct {
aUpstreams AUpstreams
freshness time.Time
- upstreams []*Upstream
+ upstreams []Upstream
}
func (al aLookup) isFresh() bool {
@@ -439,16 +481,9 @@ type UpstreamResolver struct {
// and ensures they're ready to be used.
func (u *UpstreamResolver) ParseAddresses() error {
for _, v := range u.Addresses {
- addr, err := caddy.ParseNetworkAddress(v)
+ addr, err := caddy.ParseNetworkAddressWithDefaults(v, "udp", 53)
if err != nil {
- // If a port wasn't specified for the resolver,
- // try defaulting to 53 and parse again
- if strings.Contains(err.Error(), "missing port in address") {
- addr, err = caddy.ParseNetworkAddress(v + ":53")
- }
- if err != nil {
- return err
- }
+ return err
}
if addr.PortRangeSize() != 1 {
return fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr)
@@ -458,6 +493,14 @@ func (u *UpstreamResolver) ParseAddresses() error {
return nil
}
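ParseNetworkAddressWithDefaults folds the old "missing port in address" retry into the parser itself. A short sketch of the expected behavior under that assumption (the resolver addresses are illustrative):

    package main

    import (
        "fmt"

        "github.com/caddyserver/caddy/v2"
    )

    func main() {
        // bare IP: the default network and port are applied
        addr, err := caddy.ParseNetworkAddressWithDefaults("1.1.1.1", "udp", 53)
        if err != nil {
            panic(err)
        }
        fmt.Println(addr) // expected: udp/1.1.1.1:53

        // an explicit network and port are preserved as given
        addr, err = caddy.ParseNetworkAddressWithDefaults("tcp/8.8.8.8:853", "udp", 53)
        if err != nil {
            panic(err)
        }
        fmt.Println(addr) // expected: tcp/8.8.8.8:853
    }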
+func allNew(upstreams []Upstream) []*Upstream {
+ results := make([]*Upstream, len(upstreams))
+ for i := range upstreams {
+ results[i] = &Upstream{Dial: upstreams[i].Dial}
+ }
+ return results
+}
+
var (
srvs = make(map[string]srvLookup)
srvsMu sync.RWMutex