diff options
Diffstat (limited to 'modules/caddyhttp')
-rw-r--r-- | modules/caddyhttp/reverseproxy/admin.go | 3 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/caddyfile.go | 221 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/healthchecks.go | 96 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/hosts.go | 99 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/httptransport.go | 24 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/reverseproxy.go | 274 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/selectionpolicies_test.go | 120 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/upstreams.go | 377 |
8 files changed, 933 insertions, 281 deletions
diff --git a/modules/caddyhttp/reverseproxy/admin.go b/modules/caddyhttp/reverseproxy/admin.go index 25685a3..81ec435 100644 --- a/modules/caddyhttp/reverseproxy/admin.go +++ b/modules/caddyhttp/reverseproxy/admin.go @@ -87,7 +87,7 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er return false } - upstream, ok := val.(*upstreamHost) + upstream, ok := val.(*Host) if !ok { rangeErr = caddy.APIError{ HTTPStatus: http.StatusInternalServerError, @@ -98,7 +98,6 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er results = append(results, upstreamStatus{ Address: address, - Healthy: !upstream.Unhealthy(), NumRequests: upstream.NumRequests(), Fails: upstream.Fails(), }) diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go index e127237..f4b1636 100644 --- a/modules/caddyhttp/reverseproxy/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -53,6 +53,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) // reverse_proxy [<matcher>] [<upstreams...>] { // # upstreams // to <upstreams...> +// dynamic <name> [...] // // # load balancing // lb_policy <name> [<options...>] @@ -190,6 +191,25 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } } + case "dynamic": + if !d.NextArg() { + return d.ArgErr() + } + if h.DynamicUpstreams != nil { + return d.Err("dynamic upstreams already specified") + } + dynModule := d.Val() + modID := "http.reverse_proxy.upstreams." + dynModule + unm, err := caddyfile.UnmarshalModule(d, modID) + if err != nil { + return err + } + source, ok := unm.(UpstreamSource) + if !ok { + return d.Errf("module %s (%T) is not an UpstreamSource", modID, unm) + } + h.DynamicUpstreamsRaw = caddyconfig.JSONModuleObject(source, "source", dynModule, nil) + case "lb_policy": if !d.NextArg() { return d.ArgErr() @@ -749,6 +769,7 @@ func (h *Handler) FinalizeUnmarshalCaddyfile(helper httpcaddyfile.Helper) error // dial_fallback_delay <duration> // response_header_timeout <duration> // expect_continue_timeout <duration> +// resolvers <resolvers...> // tls // tls_client_auth <automate_name> | <cert_file> <key_file> // tls_insecure_skip_verify @@ -839,6 +860,15 @@ func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } h.ExpectContinueTimeout = caddy.Duration(dur) + case "resolvers": + if h.Resolver == nil { + h.Resolver = new(UpstreamResolver) + } + h.Resolver.Addresses = d.RemainingArgs() + if len(h.Resolver.Addresses) == 0 { + return d.Errf("must specify at least one resolver address") + } + case "tls_client_auth": if h.TLS == nil { h.TLS = new(TLSConfig) @@ -989,10 +1019,201 @@ func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { return nil } +// UnmarshalCaddyfile deserializes Caddyfile tokens into h. +// +// dynamic srv [<name>] { +// service <service> +// proto <proto> +// name <name> +// refresh <interval> +// resolvers <resolvers...> +// dial_timeout <timeout> +// dial_fallback_delay <timeout> +// } +// +func (u *SRVUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + for d.Next() { + args := d.RemainingArgs() + if len(args) > 1 { + return d.ArgErr() + } + if len(args) > 0 { + u.Name = args[0] + } + + for d.NextBlock(0) { + switch d.Val() { + case "service": + if !d.NextArg() { + return d.ArgErr() + } + if u.Service != "" { + return d.Errf("srv service has already been specified") + } + u.Service = d.Val() + + case "proto": + if !d.NextArg() { + return d.ArgErr() + } + if u.Proto != "" { + return d.Errf("srv proto has already been specified") + } + u.Proto = d.Val() + + case "name": + if !d.NextArg() { + return d.ArgErr() + } + if u.Name != "" { + return d.Errf("srv name has already been specified") + } + u.Name = d.Val() + + case "refresh": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("parsing refresh interval duration: %v", err) + } + u.Refresh = caddy.Duration(dur) + + case "resolvers": + if u.Resolver == nil { + u.Resolver = new(UpstreamResolver) + } + u.Resolver.Addresses = d.RemainingArgs() + if len(u.Resolver.Addresses) == 0 { + return d.Errf("must specify at least one resolver address") + } + + case "dial_timeout": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad timeout value '%s': %v", d.Val(), err) + } + u.DialTimeout = caddy.Duration(dur) + + case "dial_fallback_delay": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad delay value '%s': %v", d.Val(), err) + } + u.FallbackDelay = caddy.Duration(dur) + + default: + return d.Errf("unrecognized srv option '%s'", d.Val()) + } + } + } + + return nil +} + +// UnmarshalCaddyfile deserializes Caddyfile tokens into h. +// +// dynamic a [<name> <port] { +// name <name> +// port <port> +// refresh <interval> +// resolvers <resolvers...> +// dial_timeout <timeout> +// dial_fallback_delay <timeout> +// } +// +func (u *AUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + for d.Next() { + args := d.RemainingArgs() + if len(args) > 2 { + return d.ArgErr() + } + if len(args) > 0 { + u.Name = args[0] + u.Port = args[1] + } + + for d.NextBlock(0) { + switch d.Val() { + case "name": + if !d.NextArg() { + return d.ArgErr() + } + if u.Name != "" { + return d.Errf("a name has already been specified") + } + u.Name = d.Val() + + case "port": + if !d.NextArg() { + return d.ArgErr() + } + if u.Port != "" { + return d.Errf("a port has already been specified") + } + u.Port = d.Val() + + case "refresh": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("parsing refresh interval duration: %v", err) + } + u.Refresh = caddy.Duration(dur) + + case "resolvers": + if u.Resolver == nil { + u.Resolver = new(UpstreamResolver) + } + u.Resolver.Addresses = d.RemainingArgs() + if len(u.Resolver.Addresses) == 0 { + return d.Errf("must specify at least one resolver address") + } + + case "dial_timeout": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad timeout value '%s': %v", d.Val(), err) + } + u.DialTimeout = caddy.Duration(dur) + + case "dial_fallback_delay": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad delay value '%s': %v", d.Val(), err) + } + u.FallbackDelay = caddy.Duration(dur) + + default: + return d.Errf("unrecognized srv option '%s'", d.Val()) + } + } + } + + return nil +} + const matcherPrefix = "@" // Interface guards var ( _ caddyfile.Unmarshaler = (*Handler)(nil) _ caddyfile.Unmarshaler = (*HTTPTransport)(nil) + _ caddyfile.Unmarshaler = (*SRVUpstreams)(nil) + _ caddyfile.Unmarshaler = (*AUpstreams)(nil) ) diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go index 230bf3a..317b283 100644 --- a/modules/caddyhttp/reverseproxy/healthchecks.go +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "io" - "log" "net" "net/http" "net/url" @@ -37,12 +36,32 @@ import ( type HealthChecks struct { // Active health checks run in the background on a timer. To // minimally enable active health checks, set either path or - // port (or both). + // port (or both). Note that active health check status + // (healthy/unhealthy) is stored per-proxy-handler, not + // globally; this allows different handlers to use different + // criteria to decide what defines a healthy backend. + // + // Active health checks do not run for dynamic upstreams. Active *ActiveHealthChecks `json:"active,omitempty"` // Passive health checks monitor proxied requests for errors or timeouts. // To minimally enable passive health checks, specify at least an empty - // config object. + // config object. Passive health check state is shared (stored globally), + // so a failure from one handler will be counted by all handlers; but + // the tolerances or standards for what defines healthy/unhealthy backends + // is configured per-proxy-handler. + // + // Passive health checks technically do operate on dynamic upstreams, + // but are only effective for very busy proxies where the list of + // upstreams is mostly stable. This is because the shared/global + // state of upstreams is cleaned up when the upstreams are no longer + // used. Since dynamic upstreams are allocated dynamically at each + // request (specifically, each iteration of the proxy loop per request), + // they are also cleaned up after every request. Thus, if there is a + // moment when no requests are actively referring to a particular + // upstream host, the passive health check state will be reset because + // it will be garbage-collected. It is usually better for the dynamic + // upstream module to only return healthy, available backends instead. Passive *PassiveHealthChecks `json:"passive,omitempty"` } @@ -50,8 +69,7 @@ type HealthChecks struct { // health checks (that is, health checks which occur in a // background goroutine independently). type ActiveHealthChecks struct { - // The path to use for health checks. - // DEPRECATED: Use 'uri' instead. + // DEPRECATED: Use 'uri' instead. This field will be removed. TODO: remove this field Path string `json:"path,omitempty"` // The URI (path and query) to use for health checks @@ -132,7 +150,9 @@ type CircuitBreaker interface { func (h *Handler) activeHealthChecker() { defer func() { if err := recover(); err != nil { - log.Printf("[PANIC] active health checks: %v\n%s", err, debug.Stack()) + h.HealthChecks.Active.logger.Error("active health checker panicked", + zap.Any("error", err), + zap.ByteString("stack", debug.Stack())) } }() ticker := time.NewTicker(time.Duration(h.HealthChecks.Active.Interval)) @@ -155,7 +175,9 @@ func (h *Handler) doActiveHealthCheckForAllHosts() { go func(upstream *Upstream) { defer func() { if err := recover(); err != nil { - log.Printf("[PANIC] active health check: %v\n%s", err, debug.Stack()) + h.HealthChecks.Active.logger.Error("active health check panicked", + zap.Any("error", err), + zap.ByteString("stack", debug.Stack())) } }() @@ -195,7 +217,7 @@ func (h *Handler) doActiveHealthCheckForAllHosts() { // so use a fake Host value instead; unix sockets are usually local hostAddr = "localhost" } - err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: dialAddr}, hostAddr, upstream.Host) + err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: dialAddr}, hostAddr, upstream) if err != nil { h.HealthChecks.Active.logger.Error("active health check failed", zap.String("address", hostAddr), @@ -206,14 +228,14 @@ func (h *Handler) doActiveHealthCheckForAllHosts() { } } -// doActiveHealthCheck performs a health check to host which +// doActiveHealthCheck performs a health check to upstream which // can be reached at address hostAddr. The actual address for // the request will be built according to active health checker // config. The health status of the host will be updated // according to whether it passes the health check. An error is // returned only if the health check fails to occur or if marking // the host's health status fails. -func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host Host) error { +func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstream *Upstream) error { // create the URL for the request that acts as a health check scheme := "http" if ht, ok := h.Transport.(TLSTransport); ok && ht.TLSEnabled() { @@ -269,10 +291,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H zap.String("host", hostAddr), zap.Error(err), ) - _, err2 := host.SetHealthy(false) - if err2 != nil { - return fmt.Errorf("marking unhealthy: %v", err2) - } + upstream.setHealthy(false) return nil } var body io.Reader = resp.Body @@ -292,10 +311,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H zap.Int("status_code", resp.StatusCode), zap.String("host", hostAddr), ) - _, err := host.SetHealthy(false) - if err != nil { - return fmt.Errorf("marking unhealthy: %v", err) - } + upstream.setHealthy(false) return nil } } else if resp.StatusCode < 200 || resp.StatusCode >= 400 { @@ -303,10 +319,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H zap.Int("status_code", resp.StatusCode), zap.String("host", hostAddr), ) - _, err := host.SetHealthy(false) - if err != nil { - return fmt.Errorf("marking unhealthy: %v", err) - } + upstream.setHealthy(false) return nil } @@ -318,33 +331,21 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H zap.String("host", hostAddr), zap.Error(err), ) - _, err := host.SetHealthy(false) - if err != nil { - return fmt.Errorf("marking unhealthy: %v", err) - } + upstream.setHealthy(false) return nil } if !h.HealthChecks.Active.bodyRegexp.Match(bodyBytes) { h.HealthChecks.Active.logger.Info("response body failed expectations", zap.String("host", hostAddr), ) - _, err := host.SetHealthy(false) - if err != nil { - return fmt.Errorf("marking unhealthy: %v", err) - } + upstream.setHealthy(false) return nil } } // passed health check parameters, so mark as healthy - swapped, err := host.SetHealthy(true) - if swapped { - h.HealthChecks.Active.logger.Info("host is up", - zap.String("host", hostAddr), - ) - } - if err != nil { - return fmt.Errorf("marking healthy: %v", err) + if upstream.setHealthy(true) { + h.HealthChecks.Active.logger.Info("host is up", zap.String("host", hostAddr)) } return nil @@ -366,7 +367,7 @@ func (h *Handler) countFailure(upstream *Upstream) { } // count failure immediately - err := upstream.Host.CountFail(1) + err := upstream.Host.countFail(1) if err != nil { h.HealthChecks.Passive.logger.Error("could not count failure", zap.String("host", upstream.Dial), @@ -375,14 +376,23 @@ func (h *Handler) countFailure(upstream *Upstream) { } // forget it later - go func(host Host, failDuration time.Duration) { + go func(host *Host, failDuration time.Duration) { defer func() { if err := recover(); err != nil { - log.Printf("[PANIC] health check failure forgetter: %v\n%s", err, debug.Stack()) + h.HealthChecks.Passive.logger.Error("passive health check failure forgetter panicked", + zap.Any("error", err), + zap.ByteString("stack", debug.Stack())) } }() - time.Sleep(failDuration) - err := host.CountFail(-1) + timer := time.NewTimer(failDuration) + select { + case <-h.ctx.Done(): + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + } + err := host.countFail(-1) if err != nil { h.HealthChecks.Passive.logger.Error("could not forget failure", zap.String("host", upstream.Dial), diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index b9817d2..a973ecb 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -26,44 +26,14 @@ import ( "github.com/caddyserver/caddy/v2/modules/caddyhttp" ) -// Host represents a remote host which can be proxied to. -// Its methods must be safe for concurrent use. -type Host interface { - // NumRequests returns the number of requests - // currently in process with the host. - NumRequests() int - - // Fails returns the count of recent failures. - Fails() int - - // Unhealthy returns true if the backend is unhealthy. - Unhealthy() bool - - // CountRequest atomically counts the given number of - // requests as currently in process with the host. The - // count should not go below 0. - CountRequest(int) error - - // CountFail atomically counts the given number of - // failures with the host. The count should not go - // below 0. - CountFail(int) error - - // SetHealthy atomically marks the host as either - // healthy (true) or unhealthy (false). If the given - // status is the same, this should be a no-op and - // return false. It returns true if the status was - // changed; i.e. if it is now different from before. - SetHealthy(bool) (bool, error) -} - // UpstreamPool is a collection of upstreams. type UpstreamPool []*Upstream // Upstream bridges this proxy's configuration to the // state of the backend host it is correlated with. +// Upstream values must not be copied. type Upstream struct { - Host `json:"-"` + *Host `json:"-"` // The [network address](/docs/conventions#network-addresses) // to dial to connect to the upstream. Must represent precisely @@ -77,6 +47,10 @@ type Upstream struct { // backends is down. Also be aware of open proxy vulnerabilities. Dial string `json:"dial,omitempty"` + // DEPRECATED: Use the SRVUpstreams module instead + // (http.reverse_proxy.upstreams.srv). This field will be + // removed in a future version of Caddy. TODO: Remove this field. + // // If DNS SRV records are used for service discovery with this // upstream, specify the DNS name for which to look up SRV // records here, instead of specifying a dial address. @@ -95,6 +69,7 @@ type Upstream struct { activeHealthCheckPort int healthCheckPolicy *PassiveHealthChecks cb CircuitBreaker + unhealthy int32 // accessed atomically; status from active health checker } func (u Upstream) String() string { @@ -117,7 +92,7 @@ func (u *Upstream) Available() bool { // is currently known to be healthy or "up". // It consults the circuit breaker, if any. func (u *Upstream) Healthy() bool { - healthy := !u.Host.Unhealthy() + healthy := u.healthy() if healthy && u.healthCheckPolicy != nil { healthy = u.Host.Fails() < u.healthCheckPolicy.MaxFails } @@ -142,7 +117,7 @@ func (u *Upstream) fillDialInfo(r *http.Request) (DialInfo, error) { var addr caddy.NetworkAddress if u.LookupSRV != "" { - // perform DNS lookup for SRV records and choose one + // perform DNS lookup for SRV records and choose one - TODO: deprecated srvName := repl.ReplaceAll(u.LookupSRV, "") _, records, err := net.DefaultResolver.LookupSRV(r.Context(), "", "", srvName) if err != nil { @@ -174,59 +149,67 @@ func (u *Upstream) fillDialInfo(r *http.Request) (DialInfo, error) { }, nil } -// upstreamHost is the basic, in-memory representation -// of the state of a remote host. It implements the -// Host interface. -type upstreamHost struct { +func (u *Upstream) fillHost() { + host := new(Host) + existingHost, loaded := hosts.LoadOrStore(u.String(), host) + if loaded { + host = existingHost.(*Host) + } + u.Host = host +} + +// Host is the basic, in-memory representation of the state of a remote host. +// Its fields are accessed atomically and Host values must not be copied. +type Host struct { numRequests int64 // must be 64-bit aligned on 32-bit systems (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG) fails int64 - unhealthy int32 } // NumRequests returns the number of active requests to the upstream. -func (uh *upstreamHost) NumRequests() int { - return int(atomic.LoadInt64(&uh.numRequests)) +func (h *Host) NumRequests() int { + return int(atomic.LoadInt64(&h.numRequests)) } // Fails returns the number of recent failures with the upstream. -func (uh *upstreamHost) Fails() int { - return int(atomic.LoadInt64(&uh.fails)) -} - -// Unhealthy returns whether the upstream is healthy. -func (uh *upstreamHost) Unhealthy() bool { - return atomic.LoadInt32(&uh.unhealthy) == 1 +func (h *Host) Fails() int { + return int(atomic.LoadInt64(&h.fails)) } -// CountRequest mutates the active request count by +// countRequest mutates the active request count by // delta. It returns an error if the adjustment fails. -func (uh *upstreamHost) CountRequest(delta int) error { - result := atomic.AddInt64(&uh.numRequests, int64(delta)) +func (h *Host) countRequest(delta int) error { + result := atomic.AddInt64(&h.numRequests, int64(delta)) if result < 0 { return fmt.Errorf("count below 0: %d", result) } return nil } -// CountFail mutates the recent failures count by +// countFail mutates the recent failures count by // delta. It returns an error if the adjustment fails. -func (uh *upstreamHost) CountFail(delta int) error { - result := atomic.AddInt64(&uh.fails, int64(delta)) +func (h *Host) countFail(delta int) error { + result := atomic.AddInt64(&h.fails, int64(delta)) if result < 0 { return fmt.Errorf("count below 0: %d", result) } return nil } +// healthy returns true if the upstream is not actively marked as unhealthy. +// (This returns the status only from the "active" health checks.) +func (u *Upstream) healthy() bool { + return atomic.LoadInt32(&u.unhealthy) == 0 +} + // SetHealthy sets the upstream has healthy or unhealthy -// and returns true if the new value is different. -func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) { +// and returns true if the new value is different. This +// sets the status only for the "active" health checks. +func (u *Upstream) setHealthy(healthy bool) bool { var unhealthy, compare int32 = 1, 0 if healthy { unhealthy, compare = 0, 1 } - swapped := atomic.CompareAndSwapInt32(&uh.unhealthy, compare, unhealthy) - return swapped, nil + return atomic.CompareAndSwapInt32(&u.unhealthy, compare, unhealthy) } // DialInfo contains information needed to dial a diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index 4be51af..f7472be 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -168,15 +168,9 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error) } if h.Resolver != nil { - for _, v := range h.Resolver.Addresses { - addr, err := caddy.ParseNetworkAddress(v) - if err != nil { - return nil, err - } - if addr.PortRangeSize() != 1 { - return nil, fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr) - } - h.Resolver.netAddrs = append(h.Resolver.netAddrs, addr) + err := h.Resolver.ParseAddresses() + if err != nil { + return nil, err } d := &net.Dialer{ Timeout: time.Duration(h.DialTimeout), @@ -406,18 +400,6 @@ func (t TLSConfig) MakeTLSClientConfig(ctx caddy.Context) (*tls.Config, error) { return cfg, nil } -// UpstreamResolver holds the set of addresses of DNS resolvers of -// upstream addresses -type UpstreamResolver struct { - // The addresses of DNS resolvers to use when looking up the addresses of proxy upstreams. - // It accepts [network addresses](/docs/conventions#network-addresses) - // with port range of only 1. If the host is an IP address, it will be dialed directly to resolve the upstream server. - // If the host is not an IP address, the addresses are resolved using the [name resolution convention](https://golang.org/pkg/net/#hdr-Name_Resolution) of the Go standard library. - // If the array contains more than 1 resolver address, one is chosen at random. - Addresses []string `json:"addresses,omitempty"` - netAddrs []caddy.NetworkAddress -} - // KeepAlive holds configuration pertaining to HTTP Keep-Alive. type KeepAlive struct { // Whether HTTP Keep-Alive is enabled. Default: true diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index a5bdc31..3355f0b 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -78,9 +78,20 @@ type Handler struct { // up or down. Down backends will not be proxied to. HealthChecks *HealthChecks `json:"health_checks,omitempty"` - // Upstreams is the list of backends to proxy to. + // Upstreams is the static list of backends to proxy to. Upstreams UpstreamPool `json:"upstreams,omitempty"` + // A module for retrieving the list of upstreams dynamically. Dynamic + // upstreams are retrieved at every iteration of the proxy loop for + // each request (i.e. before every proxy attempt within every request). + // Active health checks do not work on dynamic upstreams, and passive + // health checks are only effective on dynamic upstreams if the proxy + // server is busy enough that concurrent requests to the same backends + // are continuous. Instead of health checks for dynamic upstreams, it + // is recommended that the dynamic upstream module only return available + // backends in the first place. + DynamicUpstreamsRaw json.RawMessage `json:"dynamic_upstreams,omitempty" caddy:"namespace=http.reverse_proxy.upstreams inline_key=source"` + // Adjusts how often to flush the response buffer. By default, // no periodic flushing is done. A negative value disables // response buffering, and flushes immediately after each @@ -137,8 +148,9 @@ type Handler struct { // - `{http.reverse_proxy.header.*}` The headers from the response HandleResponse []caddyhttp.ResponseHandler `json:"handle_response,omitempty"` - Transport http.RoundTripper `json:"-"` - CB CircuitBreaker `json:"-"` + Transport http.RoundTripper `json:"-"` + CB CircuitBreaker `json:"-"` + DynamicUpstreams UpstreamSource `json:"-"` // Holds the parsed CIDR ranges from TrustedProxies trustedProxies []*net.IPNet @@ -166,7 +178,7 @@ func (h *Handler) Provision(ctx caddy.Context) error { h.ctx = ctx h.logger = ctx.Logger(h) - // verify SRV compatibility + // verify SRV compatibility - TODO: LookupSRV deprecated; will be removed for i, v := range h.Upstreams { if v.LookupSRV == "" { continue @@ -201,6 +213,13 @@ func (h *Handler) Provision(ctx caddy.Context) error { } h.CB = mod.(CircuitBreaker) } + if h.DynamicUpstreamsRaw != nil { + mod, err := ctx.LoadModule(h, "DynamicUpstreamsRaw") + if err != nil { + return fmt.Errorf("loading upstream source module: %v", err) + } + h.DynamicUpstreams = mod.(UpstreamSource) + } // parse trusted proxy CIDRs ahead of time for _, str := range h.TrustedProxies { @@ -270,38 +289,8 @@ func (h *Handler) Provision(ctx caddy.Context) error { } // set up upstreams - for _, upstream := range h.Upstreams { - // create or get the host representation for this upstream - var host Host = new(upstreamHost) - existingHost, loaded := hosts.LoadOrStore(upstream.String(), host) - if loaded { - host = existingHost.(Host) - } - upstream.Host = host - - // give it the circuit breaker, if any - upstream.cb = h.CB - - // if the passive health checker has a non-zero UnhealthyRequestCount - // but the upstream has no MaxRequests set (they are the same thing, - // but the passive health checker is a default value for for upstreams - // without MaxRequests), copy the value into this upstream, since the - // value in the upstream (MaxRequests) is what is used during - // availability checks - if h.HealthChecks != nil && h.HealthChecks.Passive != nil { - h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive") - if h.HealthChecks.Passive.UnhealthyRequestCount > 0 && - upstream.MaxRequests == 0 { - upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount - } - } - - // upstreams need independent access to the passive - // health check policy because passive health checks - // run without access to h. - if h.HealthChecks != nil { - upstream.healthCheckPolicy = h.HealthChecks.Passive - } + for _, u := range h.Upstreams { + h.provisionUpstream(u) } if h.HealthChecks != nil { @@ -413,79 +402,127 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht repl.Set("http.reverse_proxy.duration", time.Since(start)) }() + // in the proxy loop, each iteration is an attempt to proxy the request, + // and because we may retry some number of times, carry over the error + // from previous tries because of the nuances of load balancing & retries var proxyErr error for { - // choose an available upstream - upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, clonedReq, w) - if upstream == nil { - if proxyErr == nil { - proxyErr = fmt.Errorf("no upstreams available") - } - if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, clonedReq) { - break - } - continue + var done bool + done, proxyErr = h.proxyLoopIteration(clonedReq, w, proxyErr, start, repl, reqHeader, reqHost, next) + if done { + break } + } + + if proxyErr != nil { + return statusError(proxyErr) + } - // the dial address may vary per-request if placeholders are - // used, so perform those replacements here; the resulting - // DialInfo struct should have valid network address syntax - dialInfo, err := upstream.fillDialInfo(clonedReq) + return nil +} + +// proxyLoopIteration implements an iteration of the proxy loop. Despite the enormous amount of local state +// that has to be passed in, we brought this into its own method so that we could run defer more easily. +// It returns true when the loop is done and should break; false otherwise. The error value returned should +// be assigned to the proxyErr value for the next iteration of the loop (or the error handled after break). +func (h *Handler) proxyLoopIteration(r *http.Request, w http.ResponseWriter, proxyErr error, start time.Time, + repl *caddy.Replacer, reqHeader http.Header, reqHost string, next caddyhttp.Handler) (bool, error) { + // get the updated list of upstreams + upstreams := h.Upstreams + if h.DynamicUpstreams != nil { + dUpstreams, err := h.DynamicUpstreams.GetUpstreams(r) if err != nil { - return statusError(fmt.Errorf("making dial info: %v", err)) + h.logger.Error("failed getting dynamic upstreams; falling back to static upstreams", zap.Error(err)) + } else { + upstreams = dUpstreams + for _, dUp := range dUpstreams { + h.provisionUpstream(dUp) + } + h.logger.Debug("provisioned dynamic upstreams", zap.Int("count", len(dUpstreams))) + defer func() { + // these upstreams are dynamic, so they are only used for this iteration + // of the proxy loop; be sure to let them go away when we're done with them + for _, upstream := range dUpstreams { + _, _ = hosts.Delete(upstream.String()) + } + }() } + } - // attach to the request information about how to dial the upstream; - // this is necessary because the information cannot be sufficiently - // or satisfactorily represented in a URL - caddyhttp.SetVar(r.Context(), dialInfoVarKey, dialInfo) - - // set placeholders with information about this upstream - repl.Set("http.reverse_proxy.upstream.address", dialInfo.String()) - repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address) - repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host) - repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port) - repl.Set("http.reverse_proxy.upstream.requests", upstream.Host.NumRequests()) - repl.Set("http.reverse_proxy.upstream.max_requests", upstream.MaxRequests) - repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails()) - - // mutate request headers according to this upstream; - // because we're in a retry loop, we have to copy - // headers (and the Host value) from the original - // so that each retry is identical to the first - if h.Headers != nil && h.Headers.Request != nil { - clonedReq.Header = make(http.Header) - copyHeader(clonedReq.Header, reqHeader) - clonedReq.Host = reqHost - h.Headers.Request.ApplyToRequest(clonedReq) + // choose an available upstream + upstream := h.LoadBalancing.SelectionPolicy.Select(upstreams, r, w) + if upstream == nil { + if proxyErr == nil { + proxyErr = fmt.Errorf("no upstreams available") } - - // proxy the request to that upstream - proxyErr = h.reverseProxy(w, clonedReq, repl, dialInfo, next) - if proxyErr == nil || proxyErr == context.Canceled { - // context.Canceled happens when the downstream client - // cancels the request, which is not our failure - return nil + if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, r) { + return true, proxyErr } + return false, proxyErr + } - // if the roundtrip was successful, don't retry the request or - // ding the health status of the upstream (an error can still - // occur after the roundtrip if, for example, a response handler - // after the roundtrip returns an error) - if succ, ok := proxyErr.(roundtripSucceeded); ok { - return succ.error - } + // the dial address may vary per-request if placeholders are + // used, so perform those replacements here; the resulting + // DialInfo struct should have valid network address syntax + dialInfo, err := upstream.fillDialInfo(r) + if err != nil { + return true, fmt.Errorf("making dial info: %v", err) + } - // remember this failure (if enabled) - h.countFailure(upstream) + h.logger.Debug("selected upstream", + zap.String("dial", dialInfo.Address), + zap.Int("total_upstreams", len(upstreams))) - // if we've tried long enough, break - if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, clonedReq) { - break - } + // attach to the request information about how to dial the upstream; + // this is necessary because the information cannot be sufficiently + // or satisfactorily represented in a URL + caddyhttp.SetVar(r.Context(), dialInfoVarKey, dialInfo) + + // set placeholders with information about this upstream + repl.Set("http.reverse_proxy.upstream.address", dialInfo.String()) + repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address) + repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host) + repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port) + repl.Set("http.reverse_proxy.upstream.requests", upstream.Host.NumRequests()) + repl.Set("http.reverse_proxy.upstream.max_requests", upstream.MaxRequests) + repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails()) + + // mutate request headers according to this upstream; + // because we're in a retry loop, we have to copy + // headers (and the r.Host value) from the original + // so that each retry is identical to the first + if h.Headers != nil && h.Headers.Request != nil { + r.Header = make(http.Header) + copyHeader(r.Header, reqHeader) + r.Host = reqHost + h.Headers.Request.ApplyToRequest(r) } - return statusError(proxyErr) + // proxy the request to that upstream + proxyErr = h.reverseProxy(w, r, repl, dialInfo, next) + if proxyErr == nil || proxyErr == context.Canceled { + // context.Canceled happens when the downstream client + // cancels the request, which is not our failure + return true, nil + } + + // if the roundtrip was successful, don't retry the request or + // ding the health status of the upstream (an error can still + // occur after the roundtrip if, for example, a response handler + // after the roundtrip returns an error) + if succ, ok := proxyErr.(roundtripSucceeded); ok { + return true, succ.error + } + + // remember this failure (if enabled) + h.countFailure(upstream) + + // if we've tried long enough, break + if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, r) { + return true, proxyErr + } + + return false, proxyErr } // prepareRequest clones req so that it can be safely modified without @@ -651,9 +688,9 @@ func (h Handler) addForwardedHeaders(req *http.Request) error { // (This method is mostly the beginning of what was borrowed from the net/http/httputil package in the // Go standard library which was used as the foundation.) func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, repl *caddy.Replacer, di DialInfo, next caddyhttp.Handler) error { - _ = di.Upstream.Host.CountRequest(1) + _ = di.Upstream.Host.countRequest(1) //nolint:errcheck - defer di.Upstream.Host.CountRequest(-1) + defer di.Upstream.Host.countRequest(-1) // point the request to this upstream h.directRequest(req, di) @@ -905,6 +942,35 @@ func (Handler) directRequest(req *http.Request, di DialInfo) { req.URL.Host = reqHost } +func (h Handler) provisionUpstream(upstream *Upstream) { + // create or get the host representation for this upstream + upstream.fillHost() + + // give it the circuit breaker, if any + upstream.cb = h.CB + + // if the passive health checker has a non-zero UnhealthyRequestCount + // but the upstream has no MaxRequests set (they are the same thing, + // but the passive health checker is a default value for for upstreams + // without MaxRequests), copy the value into this upstream, since the + // value in the upstream (MaxRequests) is what is used during + // availability checks + if h.HealthChecks != nil && h.HealthChecks.Passive != nil { + h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive") + if h.HealthChecks.Passive.UnhealthyRequestCount > 0 && + upstream.MaxRequests == 0 { + upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount + } + } + + // upstreams need independent access to the passive + // health check policy because passive health checks + // run without access to h. + if h.HealthChecks != nil { + upstream.healthCheckPolicy = h.HealthChecks.Passive + } +} + // bufferedBody reads originalBody into a buffer, then returns a reader for the buffer. // Always close the return value when done with it, just like if it was the original body! func (h Handler) bufferedBody(originalBody io.ReadCloser) io.ReadCloser { @@ -1085,6 +1151,20 @@ type Selector interface { Select(UpstreamPool, *http.Request, http.ResponseWriter) *Upstream } +// UpstreamSource gets the list of upstreams that can be used when +// proxying a request. Returned upstreams will be load balanced and +// health-checked. This should be a very fast function -- instant +// if possible -- and the return value must be as stable as possible. +// In other words, the list of upstreams should ideally not change much +// across successive calls. If the list of upstreams changes or the +// ordering is not stable, load balancing will suffer. This function +// may be called during each retry, multiple times per request, and as +// such, needs to be instantaneous. The returned slice will not be +// modified. +type UpstreamSource interface { + GetUpstreams(*http.Request) ([]*Upstream, error) +} + // Hop-by-hop headers. These are removed when sent to the backend. // As of RFC 7230, hop-by-hop headers are required to appear in the // Connection header field. These are the headers defined by the diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go index c28799d..7175f77 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go @@ -22,9 +22,9 @@ import ( func testPool() UpstreamPool { return UpstreamPool{ - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, } } @@ -48,20 +48,20 @@ func TestRoundRobinPolicy(t *testing.T) { t.Error("Expected third round robin host to be first host in the pool.") } // mark host as down - pool[1].SetHealthy(false) + pool[1].setHealthy(false) h = rrPolicy.Select(pool, req, nil) if h != pool[2] { t.Error("Expected to skip down host.") } // mark host as up - pool[1].SetHealthy(true) + pool[1].setHealthy(true) h = rrPolicy.Select(pool, req, nil) if h == pool[2] { t.Error("Expected to balance evenly among healthy hosts") } // mark host as full - pool[1].CountRequest(1) + pool[1].countRequest(1) pool[1].MaxRequests = 1 h = rrPolicy.Select(pool, req, nil) if h != pool[2] { @@ -74,13 +74,13 @@ func TestLeastConnPolicy(t *testing.T) { lcPolicy := new(LeastConnSelection) req, _ := http.NewRequest("GET", "/", nil) - pool[0].CountRequest(10) - pool[1].CountRequest(10) + pool[0].countRequest(10) + pool[1].countRequest(10) h := lcPolicy.Select(pool, req, nil) if h != pool[2] { t.Error("Expected least connection host to be third host.") } - pool[2].CountRequest(100) + pool[2].countRequest(100) h = lcPolicy.Select(pool, req, nil) if h != pool[0] && h != pool[1] { t.Error("Expected least connection host to be first or second host.") @@ -139,7 +139,7 @@ func TestIPHashPolicy(t *testing.T) { // we should get a healthy host if the original host is unhealthy and a // healthy host is available req.RemoteAddr = "172.0.0.1" - pool[1].SetHealthy(false) + pool[1].setHealthy(false) h = ipHash.Select(pool, req, nil) if h != pool[2] { t.Error("Expected ip hash policy host to be the third host.") @@ -150,10 +150,10 @@ func TestIPHashPolicy(t *testing.T) { if h != pool[2] { t.Error("Expected ip hash policy host to be the third host.") } - pool[1].SetHealthy(true) + pool[1].setHealthy(true) req.RemoteAddr = "172.0.0.3" - pool[2].SetHealthy(false) + pool[2].setHealthy(false) h = ipHash.Select(pool, req, nil) if h != pool[0] { t.Error("Expected ip hash policy host to be the first host.") @@ -167,8 +167,8 @@ func TestIPHashPolicy(t *testing.T) { // We should be able to resize the host pool and still be able to predict // where a req will be routed with the same IP's used above pool = UpstreamPool{ - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, + {Host: new(Host)}, + {Host: new(Host)}, } req.RemoteAddr = "172.0.0.1:80" h = ipHash.Select(pool, req, nil) @@ -192,8 +192,8 @@ func TestIPHashPolicy(t *testing.T) { } // We should get nil when there are no healthy hosts - pool[0].SetHealthy(false) - pool[1].SetHealthy(false) + pool[0].setHealthy(false) + pool[1].setHealthy(false) h = ipHash.Select(pool, req, nil) if h != nil { t.Error("Expected ip hash policy host to be nil.") @@ -201,25 +201,25 @@ func TestIPHashPolicy(t *testing.T) { // Reproduce #4135 pool = UpstreamPool{ - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, - } - pool[0].SetHealthy(false) - pool[1].SetHealthy(false) - pool[2].SetHealthy(false) - pool[3].SetHealthy(false) - pool[4].SetHealthy(false) - pool[5].SetHealthy(false) - pool[6].SetHealthy(false) - pool[7].SetHealthy(false) - pool[8].SetHealthy(true) + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + {Host: new(Host)}, + } + pool[0].setHealthy(false) + pool[1].setHealthy(false) + pool[2].setHealthy(false) + pool[3].setHealthy(false) + pool[4].setHealthy(false) + pool[5].setHealthy(false) + pool[6].setHealthy(false) + pool[7].setHealthy(false) + pool[8].setHealthy(true) // We should get a result back when there is one healthy host left. h = ipHash.Select(pool, req, nil) @@ -239,7 +239,7 @@ func TestFirstPolicy(t *testing.T) { t.Error("Expected first policy host to be the first host.") } - pool[0].SetHealthy(false) + pool[0].setHealthy(false) h = firstPolicy.Select(pool, req, nil) if h != pool[1] { t.Error("Expected first policy host to be the second host.") @@ -256,7 +256,7 @@ func TestURIHashPolicy(t *testing.T) { t.Error("Expected uri policy host to be the first host.") } - pool[0].SetHealthy(false) + pool[0].setHealthy(false) h = uriPolicy.Select(pool, request, nil) if h != pool[1] { t.Error("Expected uri policy host to be the first host.") @@ -271,8 +271,8 @@ func TestURIHashPolicy(t *testing.T) { // We should be able to resize the host pool and still be able to predict // where a request will be routed with the same URI's used above pool = UpstreamPool{ - {Host: new(upstreamHost)}, - {Host: new(upstreamHost)}, + {Host: new(Host)}, + {Host: new(Host)}, } request = httptest.NewRequest(http.MethodGet, "/test", nil) @@ -281,7 +281,7 @@ func TestURIHashPolicy(t *testing.T) { t.Error("Expected uri policy host to be the first host.") } - pool[0].SetHealthy(false) + pool[0].setHealthy(false) h = uriPolicy.Select(pool, request, nil) if h != pool[1] { t.Error("Expected uri policy host to be the first host.") @@ -293,8 +293,8 @@ func TestURIHashPolicy(t *testing.T) { t.Error("Expected uri policy host to be the second host.") } - pool[0].SetHealthy(false) - pool[1].SetHealthy(false) + pool[0].setHealthy(false) + pool[1].setHealthy(false) h = uriPolicy.Select(pool, request, nil) if h != nil { t.Error("Expected uri policy policy host to be nil.") @@ -306,12 +306,12 @@ func TestLeastRequests(t *testing.T) { pool[0].Dial = "localhost:8080" pool[1].Dial = "localhost:8081" pool[2].Dial = "localhost:8082" - pool[0].SetHealthy(true) - pool[1].SetHealthy(true) - pool[2].SetHealthy(true) - pool[0].CountRequest(10) - pool[1].CountRequest(20) - pool[2].CountRequest(30) + pool[0].setHealthy(true) + pool[1].setHealthy(true) + pool[2].setHealthy(true) + pool[0].countRequest(10) + pool[1].countRequest(20) + pool[2].countRequest(30) result := leastRequests(pool) @@ -329,12 +329,12 @@ func TestRandomChoicePolicy(t *testing.T) { pool[0].Dial = "localhost:8080" pool[1].Dial = "localhost:8081" pool[2].Dial = "localhost:8082" - pool[0].SetHealthy(false) - pool[1].SetHealthy(true) - pool[2].SetHealthy(true) - pool[0].CountRequest(10) - pool[1].CountRequest(20) - pool[2].CountRequest(30) + pool[0].setHealthy(false) + pool[1].setHealthy(true) + pool[2].setHealthy(true) + pool[0].countRequest(10) + pool[1].countRequest(20) + pool[2].countRequest(30) request := httptest.NewRequest(http.MethodGet, "/test", nil) randomChoicePolicy := new(RandomChoiceSelection) @@ -357,9 +357,9 @@ func TestCookieHashPolicy(t *testing.T) { pool[0].Dial = "localhost:8080" pool[1].Dial = "localhost:8081" pool[2].Dial = "localhost:8082" - pool[0].SetHealthy(true) - pool[1].SetHealthy(false) - pool[2].SetHealthy(false) + pool[0].setHealthy(true) + pool[1].setHealthy(false) + pool[2].setHealthy(false) request := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() cookieHashPolicy := new(CookieHashSelection) @@ -374,8 +374,8 @@ func TestCookieHashPolicy(t *testing.T) { if h != pool[0] { t.Error("Expected cookieHashPolicy host to be the first only available host.") } - pool[1].SetHealthy(true) - pool[2].SetHealthy(true) + pool[1].setHealthy(true) + pool[2].setHealthy(true) request = httptest.NewRequest(http.MethodGet, "/test", nil) w = httptest.NewRecorder() request.AddCookie(cookieServer1) @@ -387,7 +387,7 @@ func TestCookieHashPolicy(t *testing.T) { if len(s) != 0 { t.Error("Expected cookieHashPolicy to not set a new cookie.") } - pool[0].SetHealthy(false) + pool[0].setHealthy(false) request = httptest.NewRequest(http.MethodGet, "/test", nil) w = httptest.NewRecorder() request.AddCookie(cookieServer1) diff --git a/modules/caddyhttp/reverseproxy/upstreams.go b/modules/caddyhttp/reverseproxy/upstreams.go new file mode 100644 index 0000000..eb5845f --- /dev/null +++ b/modules/caddyhttp/reverseproxy/upstreams.go @@ -0,0 +1,377 @@ +package reverseproxy + +import ( + "context" + "fmt" + weakrand "math/rand" + "net" + "net/http" + "strconv" + "sync" + "time" + + "github.com/caddyserver/caddy/v2" + "go.uber.org/zap" +) + +func init() { + caddy.RegisterModule(SRVUpstreams{}) + caddy.RegisterModule(AUpstreams{}) +} + +// SRVUpstreams provides upstreams from SRV lookups. +// The lookup DNS name can be configured either by +// its individual parts (that is, specifying the +// service, protocol, and name separately) to form +// the standard "_service._proto.name" domain, or +// the domain can be specified directly in name by +// leaving service and proto empty. See RFC 2782. +// +// Lookups are cached and refreshed at the configured +// refresh interval. +// +// Returned upstreams are sorted by priority and weight. +type SRVUpstreams struct { + // The service label. + Service string `json:"service,omitempty"` + + // The protocol label; either tcp or udp. + Proto string `json:"proto,omitempty"` + + // The name label; or, if service and proto are + // empty, the entire domain name to look up. + Name string `json:"name,omitempty"` + + // The interval at which to refresh the SRV lookup. + // Results are cached between lookups. Default: 1m + Refresh caddy.Duration `json:"refresh,omitempty"` + + // Configures the DNS resolver used to resolve the + // SRV address to SRV records. + Resolver *UpstreamResolver `json:"resolver,omitempty"` + + // If Resolver is configured, how long to wait before + // timing out trying to connect to the DNS server. + DialTimeout caddy.Duration `json:"dial_timeout,omitempty"` + + // If Resolver is configured, how long to wait before + // spawning an RFC 6555 Fast Fallback connection. + // A negative value disables this. + FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"` + + resolver *net.Resolver + + logger *zap.Logger +} + +// CaddyModule returns the Caddy module information. +func (SRVUpstreams) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "http.reverse_proxy.upstreams.srv", + New: func() caddy.Module { return new(SRVUpstreams) }, + } +} + +// String returns the RFC 2782 representation of the SRV domain. +func (su SRVUpstreams) String() string { + return fmt.Sprintf("_%s._%s.%s", su.Service, su.Proto, su.Name) +} + +func (su *SRVUpstreams) Provision(ctx caddy.Context) error { + su.logger = ctx.Logger(su) + if su.Refresh == 0 { + su.Refresh = caddy.Duration(time.Minute) + } + + if su.Resolver != nil { + err := su.Resolver.ParseAddresses() + if err != nil { + return err + } + d := &net.Dialer{ + Timeout: time.Duration(su.DialTimeout), + FallbackDelay: time.Duration(su.FallbackDelay), + } + su.resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { + //nolint:gosec + addr := su.Resolver.netAddrs[weakrand.Intn(len(su.Resolver.netAddrs))] + return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0)) + }, + } + } + if su.resolver == nil { + su.resolver = net.DefaultResolver + } + + return nil +} + +func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { + suStr := su.String() + + // first, use a cheap read-lock to return a cached result quickly + srvsMu.RLock() + cached := srvs[suStr] + srvsMu.RUnlock() + if cached.isFresh() { + return cached.upstreams, nil + } + + // otherwise, obtain a write-lock to update the cached value + srvsMu.Lock() + defer srvsMu.Unlock() + + // check to see if it's still stale, since we're now in a different + // lock from when we first checked freshness; another goroutine might + // have refreshed it in the meantime before we re-obtained our lock + cached = srvs[suStr] + if cached.isFresh() { + return cached.upstreams, nil + } + + // prepare parameters and perform the SRV lookup + repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) + service := repl.ReplaceAll(su.Service, "") + proto := repl.ReplaceAll(su.Proto, "") + name := repl.ReplaceAll(su.Name, "") + + su.logger.Debug("refreshing SRV upstreams", + zap.String("service", service), + zap.String("proto", proto), + zap.String("name", name)) + + _, records, err := su.resolver.LookupSRV(r.Context(), service, proto, name) + if err != nil { + // From LookupSRV docs: "If the response contains invalid names, those records are filtered + // out and an error will be returned alongside the the remaining results, if any." Thus, we + // only return an error if no records were also returned. + if len(records) == 0 { + return nil, err + } + su.logger.Warn("SRV records filtered", zap.Error(err)) + } + + upstreams := make([]*Upstream, len(records)) + for i, rec := range records { + su.logger.Debug("discovered SRV record", + zap.String("target", rec.Target), + zap.Uint16("port", rec.Port), + zap.Uint16("priority", rec.Priority), + zap.Uint16("weight", rec.Weight)) + addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port))) + upstreams[i] = &Upstream{Dial: addr} + } + + // before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full + if cached.freshness.IsZero() && len(srvs) >= 100 { + for randomKey := range srvs { + delete(srvs, randomKey) + break + } + } + + srvs[suStr] = srvLookup{ + srvUpstreams: su, + freshness: time.Now(), + upstreams: upstreams, + } + + return upstreams, nil +} + +type srvLookup struct { + srvUpstreams SRVUpstreams + freshness time.Time + upstreams []*Upstream +} + +func (sl srvLookup) isFresh() bool { + return time.Since(sl.freshness) < time.Duration(sl.srvUpstreams.Refresh) +} + +var ( + srvs = make(map[string]srvLookup) + srvsMu sync.RWMutex +) + +// AUpstreams provides upstreams from A/AAAA lookups. +// Results are cached and refreshed at the configured +// refresh interval. +type AUpstreams struct { + // The domain name to look up. + Name string `json:"name,omitempty"` + + // The port to use with the upstreams. Default: 80 + Port string `json:"port,omitempty"` + + // The interval at which to refresh the A lookup. + // Results are cached between lookups. Default: 1m + Refresh caddy.Duration `json:"refresh,omitempty"` + + // Configures the DNS resolver used to resolve the + // domain name to A records. + Resolver *UpstreamResolver `json:"resolver,omitempty"` + + // If Resolver is configured, how long to wait before + // timing out trying to connect to the DNS server. + DialTimeout caddy.Duration `json:"dial_timeout,omitempty"` + + // If Resolver is configured, how long to wait before + // spawning an RFC 6555 Fast Fallback connection. + // A negative value disables this. + FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"` + + resolver *net.Resolver +} + +// CaddyModule returns the Caddy module information. +func (AUpstreams) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "http.reverse_proxy.upstreams.a", + New: func() caddy.Module { return new(AUpstreams) }, + } +} + +func (au AUpstreams) String() string { return au.Name } + +func (au *AUpstreams) Provision(_ caddy.Context) error { + if au.Refresh == 0 { + au.Refresh = caddy.Duration(time.Minute) + } + if au.Port == "" { + au.Port = "80" + } + + if au.Resolver != nil { + err := au.Resolver.ParseAddresses() + if err != nil { + return err + } + d := &net.Dialer{ + Timeout: time.Duration(au.DialTimeout), + FallbackDelay: time.Duration(au.FallbackDelay), + } + au.resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { + //nolint:gosec + addr := au.Resolver.netAddrs[weakrand.Intn(len(au.Resolver.netAddrs))] + return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0)) + }, + } + } + if au.resolver == nil { + au.resolver = net.DefaultResolver + } + + return nil +} + +func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { + auStr := au.String() + + // first, use a cheap read-lock to return a cached result quickly + aAaaaMu.RLock() + cached := aAaaa[auStr] + aAaaaMu.RUnlock() + if cached.isFresh() { + return cached.upstreams, nil + } + + // otherwise, obtain a write-lock to update the cached value + aAaaaMu.Lock() + defer aAaaaMu.Unlock() + + // check to see if it's still stale, since we're now in a different + // lock from when we first checked freshness; another goroutine might + // have refreshed it in the meantime before we re-obtained our lock + cached = aAaaa[auStr] + if cached.isFresh() { + return cached.upstreams, nil + } + + repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) + name := repl.ReplaceAll(au.Name, "") + port := repl.ReplaceAll(au.Port, "") + + ips, err := au.resolver.LookupIPAddr(r.Context(), name) + if err != nil { + return nil, err + } + + upstreams := make([]*Upstream, len(ips)) + for i, ip := range ips { + upstreams[i] = &Upstream{ + Dial: net.JoinHostPort(ip.String(), port), + } + } + + // before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full + if cached.freshness.IsZero() && len(srvs) >= 100 { + for randomKey := range aAaaa { + delete(aAaaa, randomKey) + break + } + } + + aAaaa[auStr] = aLookup{ + aUpstreams: au, + freshness: time.Now(), + upstreams: upstreams, + } + + return upstreams, nil +} + +type aLookup struct { + aUpstreams AUpstreams + freshness time.Time + upstreams []*Upstream +} + +func (al aLookup) isFresh() bool { + return time.Since(al.freshness) < time.Duration(al.aUpstreams.Refresh) +} + +// UpstreamResolver holds the set of addresses of DNS resolvers of +// upstream addresses +type UpstreamResolver struct { + // The addresses of DNS resolvers to use when looking up the addresses of proxy upstreams. + // It accepts [network addresses](/docs/conventions#network-addresses) + // with port range of only 1. If the host is an IP address, it will be dialed directly to resolve the upstream server. + // If the host is not an IP address, the addresses are resolved using the [name resolution convention](https://golang.org/pkg/net/#hdr-Name_Resolution) of the Go standard library. + // If the array contains more than 1 resolver address, one is chosen at random. + Addresses []string `json:"addresses,omitempty"` + netAddrs []caddy.NetworkAddress +} + +// ParseAddresses parses all the configured network addresses +// and ensures they're ready to be used. +func (u *UpstreamResolver) ParseAddresses() error { + for _, v := range u.Addresses { + addr, err := caddy.ParseNetworkAddress(v) + if err != nil { + return err + } + if addr.PortRangeSize() != 1 { + return fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr) + } + u.netAddrs = append(u.netAddrs, addr) + } + return nil +} + +var ( + aAaaa = make(map[string]aLookup) + aAaaaMu sync.RWMutex +) + +// Interface guards +var ( + _ caddy.Provisioner = (*SRVUpstreams)(nil) + _ UpstreamSource = (*SRVUpstreams)(nil) + _ caddy.Provisioner = (*AUpstreams)(nil) + _ UpstreamSource = (*AUpstreams)(nil) +) |