summaryrefslogtreecommitdiff
path: root/modules/caddyhttp/reverseproxy
diff options
context:
space:
mode:
Diffstat (limited to 'modules/caddyhttp/reverseproxy')
-rw-r--r--modules/caddyhttp/reverseproxy/admin.go3
-rw-r--r--modules/caddyhttp/reverseproxy/caddyfile.go221
-rw-r--r--modules/caddyhttp/reverseproxy/healthchecks.go96
-rw-r--r--modules/caddyhttp/reverseproxy/hosts.go99
-rw-r--r--modules/caddyhttp/reverseproxy/httptransport.go24
-rw-r--r--modules/caddyhttp/reverseproxy/reverseproxy.go274
-rw-r--r--modules/caddyhttp/reverseproxy/selectionpolicies_test.go120
-rw-r--r--modules/caddyhttp/reverseproxy/upstreams.go377
8 files changed, 933 insertions, 281 deletions
diff --git a/modules/caddyhttp/reverseproxy/admin.go b/modules/caddyhttp/reverseproxy/admin.go
index 25685a3..81ec435 100644
--- a/modules/caddyhttp/reverseproxy/admin.go
+++ b/modules/caddyhttp/reverseproxy/admin.go
@@ -87,7 +87,7 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er
return false
}
- upstream, ok := val.(*upstreamHost)
+ upstream, ok := val.(*Host)
if !ok {
rangeErr = caddy.APIError{
HTTPStatus: http.StatusInternalServerError,
@@ -98,7 +98,6 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er
results = append(results, upstreamStatus{
Address: address,
- Healthy: !upstream.Unhealthy(),
NumRequests: upstream.NumRequests(),
Fails: upstream.Fails(),
})
diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go
index e127237..f4b1636 100644
--- a/modules/caddyhttp/reverseproxy/caddyfile.go
+++ b/modules/caddyhttp/reverseproxy/caddyfile.go
@@ -53,6 +53,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
// reverse_proxy [<matcher>] [<upstreams...>] {
// # upstreams
// to <upstreams...>
+// dynamic <name> [...]
//
// # load balancing
// lb_policy <name> [<options...>]
@@ -190,6 +191,25 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
}
+ case "dynamic":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.DynamicUpstreams != nil {
+ return d.Err("dynamic upstreams already specified")
+ }
+ dynModule := d.Val()
+ modID := "http.reverse_proxy.upstreams." + dynModule
+ unm, err := caddyfile.UnmarshalModule(d, modID)
+ if err != nil {
+ return err
+ }
+ source, ok := unm.(UpstreamSource)
+ if !ok {
+ return d.Errf("module %s (%T) is not an UpstreamSource", modID, unm)
+ }
+ h.DynamicUpstreamsRaw = caddyconfig.JSONModuleObject(source, "source", dynModule, nil)
+
case "lb_policy":
if !d.NextArg() {
return d.ArgErr()
@@ -749,6 +769,7 @@ func (h *Handler) FinalizeUnmarshalCaddyfile(helper httpcaddyfile.Helper) error
// dial_fallback_delay <duration>
// response_header_timeout <duration>
// expect_continue_timeout <duration>
+// resolvers <resolvers...>
// tls
// tls_client_auth <automate_name> | <cert_file> <key_file>
// tls_insecure_skip_verify
@@ -839,6 +860,15 @@ func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
h.ExpectContinueTimeout = caddy.Duration(dur)
+ case "resolvers":
+ if h.Resolver == nil {
+ h.Resolver = new(UpstreamResolver)
+ }
+ h.Resolver.Addresses = d.RemainingArgs()
+ if len(h.Resolver.Addresses) == 0 {
+ return d.Errf("must specify at least one resolver address")
+ }
+
case "tls_client_auth":
if h.TLS == nil {
h.TLS = new(TLSConfig)
@@ -989,10 +1019,201 @@ func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return nil
}
+// UnmarshalCaddyfile deserializes Caddyfile tokens into h.
+//
+// dynamic srv [<name>] {
+// service <service>
+// proto <proto>
+// name <name>
+// refresh <interval>
+// resolvers <resolvers...>
+// dial_timeout <timeout>
+// dial_fallback_delay <timeout>
+// }
+//
+func (u *SRVUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
+ for d.Next() {
+ args := d.RemainingArgs()
+ if len(args) > 1 {
+ return d.ArgErr()
+ }
+ if len(args) > 0 {
+ u.Name = args[0]
+ }
+
+ for d.NextBlock(0) {
+ switch d.Val() {
+ case "service":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if u.Service != "" {
+ return d.Errf("srv service has already been specified")
+ }
+ u.Service = d.Val()
+
+ case "proto":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if u.Proto != "" {
+ return d.Errf("srv proto has already been specified")
+ }
+ u.Proto = d.Val()
+
+ case "name":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if u.Name != "" {
+ return d.Errf("srv name has already been specified")
+ }
+ u.Name = d.Val()
+
+ case "refresh":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("parsing refresh interval duration: %v", err)
+ }
+ u.Refresh = caddy.Duration(dur)
+
+ case "resolvers":
+ if u.Resolver == nil {
+ u.Resolver = new(UpstreamResolver)
+ }
+ u.Resolver.Addresses = d.RemainingArgs()
+ if len(u.Resolver.Addresses) == 0 {
+ return d.Errf("must specify at least one resolver address")
+ }
+
+ case "dial_timeout":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad timeout value '%s': %v", d.Val(), err)
+ }
+ u.DialTimeout = caddy.Duration(dur)
+
+ case "dial_fallback_delay":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad delay value '%s': %v", d.Val(), err)
+ }
+ u.FallbackDelay = caddy.Duration(dur)
+
+ default:
+ return d.Errf("unrecognized srv option '%s'", d.Val())
+ }
+ }
+ }
+
+ return nil
+}
+
+// UnmarshalCaddyfile deserializes Caddyfile tokens into h.
+//
+// dynamic a [<name> <port] {
+// name <name>
+// port <port>
+// refresh <interval>
+// resolvers <resolvers...>
+// dial_timeout <timeout>
+// dial_fallback_delay <timeout>
+// }
+//
+func (u *AUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
+ for d.Next() {
+ args := d.RemainingArgs()
+ if len(args) > 2 {
+ return d.ArgErr()
+ }
+ if len(args) > 0 {
+ u.Name = args[0]
+ u.Port = args[1]
+ }
+
+ for d.NextBlock(0) {
+ switch d.Val() {
+ case "name":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if u.Name != "" {
+ return d.Errf("a name has already been specified")
+ }
+ u.Name = d.Val()
+
+ case "port":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if u.Port != "" {
+ return d.Errf("a port has already been specified")
+ }
+ u.Port = d.Val()
+
+ case "refresh":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("parsing refresh interval duration: %v", err)
+ }
+ u.Refresh = caddy.Duration(dur)
+
+ case "resolvers":
+ if u.Resolver == nil {
+ u.Resolver = new(UpstreamResolver)
+ }
+ u.Resolver.Addresses = d.RemainingArgs()
+ if len(u.Resolver.Addresses) == 0 {
+ return d.Errf("must specify at least one resolver address")
+ }
+
+ case "dial_timeout":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad timeout value '%s': %v", d.Val(), err)
+ }
+ u.DialTimeout = caddy.Duration(dur)
+
+ case "dial_fallback_delay":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ dur, err := caddy.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad delay value '%s': %v", d.Val(), err)
+ }
+ u.FallbackDelay = caddy.Duration(dur)
+
+ default:
+ return d.Errf("unrecognized srv option '%s'", d.Val())
+ }
+ }
+ }
+
+ return nil
+}
+
const matcherPrefix = "@"
// Interface guards
var (
_ caddyfile.Unmarshaler = (*Handler)(nil)
_ caddyfile.Unmarshaler = (*HTTPTransport)(nil)
+ _ caddyfile.Unmarshaler = (*SRVUpstreams)(nil)
+ _ caddyfile.Unmarshaler = (*AUpstreams)(nil)
)
diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go
index 230bf3a..317b283 100644
--- a/modules/caddyhttp/reverseproxy/healthchecks.go
+++ b/modules/caddyhttp/reverseproxy/healthchecks.go
@@ -18,7 +18,6 @@ import (
"context"
"fmt"
"io"
- "log"
"net"
"net/http"
"net/url"
@@ -37,12 +36,32 @@ import (
type HealthChecks struct {
// Active health checks run in the background on a timer. To
// minimally enable active health checks, set either path or
- // port (or both).
+ // port (or both). Note that active health check status
+ // (healthy/unhealthy) is stored per-proxy-handler, not
+ // globally; this allows different handlers to use different
+ // criteria to decide what defines a healthy backend.
+ //
+ // Active health checks do not run for dynamic upstreams.
Active *ActiveHealthChecks `json:"active,omitempty"`
// Passive health checks monitor proxied requests for errors or timeouts.
// To minimally enable passive health checks, specify at least an empty
- // config object.
+ // config object. Passive health check state is shared (stored globally),
+ // so a failure from one handler will be counted by all handlers; but
+ // the tolerances or standards for what defines healthy/unhealthy backends
+ // is configured per-proxy-handler.
+ //
+ // Passive health checks technically do operate on dynamic upstreams,
+ // but are only effective for very busy proxies where the list of
+ // upstreams is mostly stable. This is because the shared/global
+ // state of upstreams is cleaned up when the upstreams are no longer
+ // used. Since dynamic upstreams are allocated dynamically at each
+ // request (specifically, each iteration of the proxy loop per request),
+ // they are also cleaned up after every request. Thus, if there is a
+ // moment when no requests are actively referring to a particular
+ // upstream host, the passive health check state will be reset because
+ // it will be garbage-collected. It is usually better for the dynamic
+ // upstream module to only return healthy, available backends instead.
Passive *PassiveHealthChecks `json:"passive,omitempty"`
}
@@ -50,8 +69,7 @@ type HealthChecks struct {
// health checks (that is, health checks which occur in a
// background goroutine independently).
type ActiveHealthChecks struct {
- // The path to use for health checks.
- // DEPRECATED: Use 'uri' instead.
+ // DEPRECATED: Use 'uri' instead. This field will be removed. TODO: remove this field
Path string `json:"path,omitempty"`
// The URI (path and query) to use for health checks
@@ -132,7 +150,9 @@ type CircuitBreaker interface {
func (h *Handler) activeHealthChecker() {
defer func() {
if err := recover(); err != nil {
- log.Printf("[PANIC] active health checks: %v\n%s", err, debug.Stack())
+ h.HealthChecks.Active.logger.Error("active health checker panicked",
+ zap.Any("error", err),
+ zap.ByteString("stack", debug.Stack()))
}
}()
ticker := time.NewTicker(time.Duration(h.HealthChecks.Active.Interval))
@@ -155,7 +175,9 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
go func(upstream *Upstream) {
defer func() {
if err := recover(); err != nil {
- log.Printf("[PANIC] active health check: %v\n%s", err, debug.Stack())
+ h.HealthChecks.Active.logger.Error("active health check panicked",
+ zap.Any("error", err),
+ zap.ByteString("stack", debug.Stack()))
}
}()
@@ -195,7 +217,7 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
// so use a fake Host value instead; unix sockets are usually local
hostAddr = "localhost"
}
- err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: dialAddr}, hostAddr, upstream.Host)
+ err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: dialAddr}, hostAddr, upstream)
if err != nil {
h.HealthChecks.Active.logger.Error("active health check failed",
zap.String("address", hostAddr),
@@ -206,14 +228,14 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
}
}
-// doActiveHealthCheck performs a health check to host which
+// doActiveHealthCheck performs a health check to upstream which
// can be reached at address hostAddr. The actual address for
// the request will be built according to active health checker
// config. The health status of the host will be updated
// according to whether it passes the health check. An error is
// returned only if the health check fails to occur or if marking
// the host's health status fails.
-func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host Host) error {
+func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstream *Upstream) error {
// create the URL for the request that acts as a health check
scheme := "http"
if ht, ok := h.Transport.(TLSTransport); ok && ht.TLSEnabled() {
@@ -269,10 +291,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H
zap.String("host", hostAddr),
zap.Error(err),
)
- _, err2 := host.SetHealthy(false)
- if err2 != nil {
- return fmt.Errorf("marking unhealthy: %v", err2)
- }
+ upstream.setHealthy(false)
return nil
}
var body io.Reader = resp.Body
@@ -292,10 +311,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H
zap.Int("status_code", resp.StatusCode),
zap.String("host", hostAddr),
)
- _, err := host.SetHealthy(false)
- if err != nil {
- return fmt.Errorf("marking unhealthy: %v", err)
- }
+ upstream.setHealthy(false)
return nil
}
} else if resp.StatusCode < 200 || resp.StatusCode >= 400 {
@@ -303,10 +319,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H
zap.Int("status_code", resp.StatusCode),
zap.String("host", hostAddr),
)
- _, err := host.SetHealthy(false)
- if err != nil {
- return fmt.Errorf("marking unhealthy: %v", err)
- }
+ upstream.setHealthy(false)
return nil
}
@@ -318,33 +331,21 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H
zap.String("host", hostAddr),
zap.Error(err),
)
- _, err := host.SetHealthy(false)
- if err != nil {
- return fmt.Errorf("marking unhealthy: %v", err)
- }
+ upstream.setHealthy(false)
return nil
}
if !h.HealthChecks.Active.bodyRegexp.Match(bodyBytes) {
h.HealthChecks.Active.logger.Info("response body failed expectations",
zap.String("host", hostAddr),
)
- _, err := host.SetHealthy(false)
- if err != nil {
- return fmt.Errorf("marking unhealthy: %v", err)
- }
+ upstream.setHealthy(false)
return nil
}
}
// passed health check parameters, so mark as healthy
- swapped, err := host.SetHealthy(true)
- if swapped {
- h.HealthChecks.Active.logger.Info("host is up",
- zap.String("host", hostAddr),
- )
- }
- if err != nil {
- return fmt.Errorf("marking healthy: %v", err)
+ if upstream.setHealthy(true) {
+ h.HealthChecks.Active.logger.Info("host is up", zap.String("host", hostAddr))
}
return nil
@@ -366,7 +367,7 @@ func (h *Handler) countFailure(upstream *Upstream) {
}
// count failure immediately
- err := upstream.Host.CountFail(1)
+ err := upstream.Host.countFail(1)
if err != nil {
h.HealthChecks.Passive.logger.Error("could not count failure",
zap.String("host", upstream.Dial),
@@ -375,14 +376,23 @@ func (h *Handler) countFailure(upstream *Upstream) {
}
// forget it later
- go func(host Host, failDuration time.Duration) {
+ go func(host *Host, failDuration time.Duration) {
defer func() {
if err := recover(); err != nil {
- log.Printf("[PANIC] health check failure forgetter: %v\n%s", err, debug.Stack())
+ h.HealthChecks.Passive.logger.Error("passive health check failure forgetter panicked",
+ zap.Any("error", err),
+ zap.ByteString("stack", debug.Stack()))
}
}()
- time.Sleep(failDuration)
- err := host.CountFail(-1)
+ timer := time.NewTimer(failDuration)
+ select {
+ case <-h.ctx.Done():
+ if !timer.Stop() {
+ <-timer.C
+ }
+ case <-timer.C:
+ }
+ err := host.countFail(-1)
if err != nil {
h.HealthChecks.Passive.logger.Error("could not forget failure",
zap.String("host", upstream.Dial),
diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go
index b9817d2..a973ecb 100644
--- a/modules/caddyhttp/reverseproxy/hosts.go
+++ b/modules/caddyhttp/reverseproxy/hosts.go
@@ -26,44 +26,14 @@ import (
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
-// Host represents a remote host which can be proxied to.
-// Its methods must be safe for concurrent use.
-type Host interface {
- // NumRequests returns the number of requests
- // currently in process with the host.
- NumRequests() int
-
- // Fails returns the count of recent failures.
- Fails() int
-
- // Unhealthy returns true if the backend is unhealthy.
- Unhealthy() bool
-
- // CountRequest atomically counts the given number of
- // requests as currently in process with the host. The
- // count should not go below 0.
- CountRequest(int) error
-
- // CountFail atomically counts the given number of
- // failures with the host. The count should not go
- // below 0.
- CountFail(int) error
-
- // SetHealthy atomically marks the host as either
- // healthy (true) or unhealthy (false). If the given
- // status is the same, this should be a no-op and
- // return false. It returns true if the status was
- // changed; i.e. if it is now different from before.
- SetHealthy(bool) (bool, error)
-}
-
// UpstreamPool is a collection of upstreams.
type UpstreamPool []*Upstream
// Upstream bridges this proxy's configuration to the
// state of the backend host it is correlated with.
+// Upstream values must not be copied.
type Upstream struct {
- Host `json:"-"`
+ *Host `json:"-"`
// The [network address](/docs/conventions#network-addresses)
// to dial to connect to the upstream. Must represent precisely
@@ -77,6 +47,10 @@ type Upstream struct {
// backends is down. Also be aware of open proxy vulnerabilities.
Dial string `json:"dial,omitempty"`
+ // DEPRECATED: Use the SRVUpstreams module instead
+ // (http.reverse_proxy.upstreams.srv). This field will be
+ // removed in a future version of Caddy. TODO: Remove this field.
+ //
// If DNS SRV records are used for service discovery with this
// upstream, specify the DNS name for which to look up SRV
// records here, instead of specifying a dial address.
@@ -95,6 +69,7 @@ type Upstream struct {
activeHealthCheckPort int
healthCheckPolicy *PassiveHealthChecks
cb CircuitBreaker
+ unhealthy int32 // accessed atomically; status from active health checker
}
func (u Upstream) String() string {
@@ -117,7 +92,7 @@ func (u *Upstream) Available() bool {
// is currently known to be healthy or "up".
// It consults the circuit breaker, if any.
func (u *Upstream) Healthy() bool {
- healthy := !u.Host.Unhealthy()
+ healthy := u.healthy()
if healthy && u.healthCheckPolicy != nil {
healthy = u.Host.Fails() < u.healthCheckPolicy.MaxFails
}
@@ -142,7 +117,7 @@ func (u *Upstream) fillDialInfo(r *http.Request) (DialInfo, error) {
var addr caddy.NetworkAddress
if u.LookupSRV != "" {
- // perform DNS lookup for SRV records and choose one
+ // perform DNS lookup for SRV records and choose one - TODO: deprecated
srvName := repl.ReplaceAll(u.LookupSRV, "")
_, records, err := net.DefaultResolver.LookupSRV(r.Context(), "", "", srvName)
if err != nil {
@@ -174,59 +149,67 @@ func (u *Upstream) fillDialInfo(r *http.Request) (DialInfo, error) {
}, nil
}
-// upstreamHost is the basic, in-memory representation
-// of the state of a remote host. It implements the
-// Host interface.
-type upstreamHost struct {
+func (u *Upstream) fillHost() {
+ host := new(Host)
+ existingHost, loaded := hosts.LoadOrStore(u.String(), host)
+ if loaded {
+ host = existingHost.(*Host)
+ }
+ u.Host = host
+}
+
+// Host is the basic, in-memory representation of the state of a remote host.
+// Its fields are accessed atomically and Host values must not be copied.
+type Host struct {
numRequests int64 // must be 64-bit aligned on 32-bit systems (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
fails int64
- unhealthy int32
}
// NumRequests returns the number of active requests to the upstream.
-func (uh *upstreamHost) NumRequests() int {
- return int(atomic.LoadInt64(&uh.numRequests))
+func (h *Host) NumRequests() int {
+ return int(atomic.LoadInt64(&h.numRequests))
}
// Fails returns the number of recent failures with the upstream.
-func (uh *upstreamHost) Fails() int {
- return int(atomic.LoadInt64(&uh.fails))
-}
-
-// Unhealthy returns whether the upstream is healthy.
-func (uh *upstreamHost) Unhealthy() bool {
- return atomic.LoadInt32(&uh.unhealthy) == 1
+func (h *Host) Fails() int {
+ return int(atomic.LoadInt64(&h.fails))
}
-// CountRequest mutates the active request count by
+// countRequest mutates the active request count by
// delta. It returns an error if the adjustment fails.
-func (uh *upstreamHost) CountRequest(delta int) error {
- result := atomic.AddInt64(&uh.numRequests, int64(delta))
+func (h *Host) countRequest(delta int) error {
+ result := atomic.AddInt64(&h.numRequests, int64(delta))
if result < 0 {
return fmt.Errorf("count below 0: %d", result)
}
return nil
}
-// CountFail mutates the recent failures count by
+// countFail mutates the recent failures count by
// delta. It returns an error if the adjustment fails.
-func (uh *upstreamHost) CountFail(delta int) error {
- result := atomic.AddInt64(&uh.fails, int64(delta))
+func (h *Host) countFail(delta int) error {
+ result := atomic.AddInt64(&h.fails, int64(delta))
if result < 0 {
return fmt.Errorf("count below 0: %d", result)
}
return nil
}
+// healthy returns true if the upstream is not actively marked as unhealthy.
+// (This returns the status only from the "active" health checks.)
+func (u *Upstream) healthy() bool {
+ return atomic.LoadInt32(&u.unhealthy) == 0
+}
+
// SetHealthy sets the upstream has healthy or unhealthy
-// and returns true if the new value is different.
-func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) {
+// and returns true if the new value is different. This
+// sets the status only for the "active" health checks.
+func (u *Upstream) setHealthy(healthy bool) bool {
var unhealthy, compare int32 = 1, 0
if healthy {
unhealthy, compare = 0, 1
}
- swapped := atomic.CompareAndSwapInt32(&uh.unhealthy, compare, unhealthy)
- return swapped, nil
+ return atomic.CompareAndSwapInt32(&u.unhealthy, compare, unhealthy)
}
// DialInfo contains information needed to dial a
diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go
index 4be51af..f7472be 100644
--- a/modules/caddyhttp/reverseproxy/httptransport.go
+++ b/modules/caddyhttp/reverseproxy/httptransport.go
@@ -168,15 +168,9 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error)
}
if h.Resolver != nil {
- for _, v := range h.Resolver.Addresses {
- addr, err := caddy.ParseNetworkAddress(v)
- if err != nil {
- return nil, err
- }
- if addr.PortRangeSize() != 1 {
- return nil, fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr)
- }
- h.Resolver.netAddrs = append(h.Resolver.netAddrs, addr)
+ err := h.Resolver.ParseAddresses()
+ if err != nil {
+ return nil, err
}
d := &net.Dialer{
Timeout: time.Duration(h.DialTimeout),
@@ -406,18 +400,6 @@ func (t TLSConfig) MakeTLSClientConfig(ctx caddy.Context) (*tls.Config, error) {
return cfg, nil
}
-// UpstreamResolver holds the set of addresses of DNS resolvers of
-// upstream addresses
-type UpstreamResolver struct {
- // The addresses of DNS resolvers to use when looking up the addresses of proxy upstreams.
- // It accepts [network addresses](/docs/conventions#network-addresses)
- // with port range of only 1. If the host is an IP address, it will be dialed directly to resolve the upstream server.
- // If the host is not an IP address, the addresses are resolved using the [name resolution convention](https://golang.org/pkg/net/#hdr-Name_Resolution) of the Go standard library.
- // If the array contains more than 1 resolver address, one is chosen at random.
- Addresses []string `json:"addresses,omitempty"`
- netAddrs []caddy.NetworkAddress
-}
-
// KeepAlive holds configuration pertaining to HTTP Keep-Alive.
type KeepAlive struct {
// Whether HTTP Keep-Alive is enabled. Default: true
diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go
index a5bdc31..3355f0b 100644
--- a/modules/caddyhttp/reverseproxy/reverseproxy.go
+++ b/modules/caddyhttp/reverseproxy/reverseproxy.go
@@ -78,9 +78,20 @@ type Handler struct {
// up or down. Down backends will not be proxied to.
HealthChecks *HealthChecks `json:"health_checks,omitempty"`
- // Upstreams is the list of backends to proxy to.
+ // Upstreams is the static list of backends to proxy to.
Upstreams UpstreamPool `json:"upstreams,omitempty"`
+ // A module for retrieving the list of upstreams dynamically. Dynamic
+ // upstreams are retrieved at every iteration of the proxy loop for
+ // each request (i.e. before every proxy attempt within every request).
+ // Active health checks do not work on dynamic upstreams, and passive
+ // health checks are only effective on dynamic upstreams if the proxy
+ // server is busy enough that concurrent requests to the same backends
+ // are continuous. Instead of health checks for dynamic upstreams, it
+ // is recommended that the dynamic upstream module only return available
+ // backends in the first place.
+ DynamicUpstreamsRaw json.RawMessage `json:"dynamic_upstreams,omitempty" caddy:"namespace=http.reverse_proxy.upstreams inline_key=source"`
+
// Adjusts how often to flush the response buffer. By default,
// no periodic flushing is done. A negative value disables
// response buffering, and flushes immediately after each
@@ -137,8 +148,9 @@ type Handler struct {
// - `{http.reverse_proxy.header.*}` The headers from the response
HandleResponse []caddyhttp.ResponseHandler `json:"handle_response,omitempty"`
- Transport http.RoundTripper `json:"-"`
- CB CircuitBreaker `json:"-"`
+ Transport http.RoundTripper `json:"-"`
+ CB CircuitBreaker `json:"-"`
+ DynamicUpstreams UpstreamSource `json:"-"`
// Holds the parsed CIDR ranges from TrustedProxies
trustedProxies []*net.IPNet
@@ -166,7 +178,7 @@ func (h *Handler) Provision(ctx caddy.Context) error {
h.ctx = ctx
h.logger = ctx.Logger(h)
- // verify SRV compatibility
+ // verify SRV compatibility - TODO: LookupSRV deprecated; will be removed
for i, v := range h.Upstreams {
if v.LookupSRV == "" {
continue
@@ -201,6 +213,13 @@ func (h *Handler) Provision(ctx caddy.Context) error {
}
h.CB = mod.(CircuitBreaker)
}
+ if h.DynamicUpstreamsRaw != nil {
+ mod, err := ctx.LoadModule(h, "DynamicUpstreamsRaw")
+ if err != nil {
+ return fmt.Errorf("loading upstream source module: %v", err)
+ }
+ h.DynamicUpstreams = mod.(UpstreamSource)
+ }
// parse trusted proxy CIDRs ahead of time
for _, str := range h.TrustedProxies {
@@ -270,38 +289,8 @@ func (h *Handler) Provision(ctx caddy.Context) error {
}
// set up upstreams
- for _, upstream := range h.Upstreams {
- // create or get the host representation for this upstream
- var host Host = new(upstreamHost)
- existingHost, loaded := hosts.LoadOrStore(upstream.String(), host)
- if loaded {
- host = existingHost.(Host)
- }
- upstream.Host = host
-
- // give it the circuit breaker, if any
- upstream.cb = h.CB
-
- // if the passive health checker has a non-zero UnhealthyRequestCount
- // but the upstream has no MaxRequests set (they are the same thing,
- // but the passive health checker is a default value for for upstreams
- // without MaxRequests), copy the value into this upstream, since the
- // value in the upstream (MaxRequests) is what is used during
- // availability checks
- if h.HealthChecks != nil && h.HealthChecks.Passive != nil {
- h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive")
- if h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
- upstream.MaxRequests == 0 {
- upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
- }
- }
-
- // upstreams need independent access to the passive
- // health check policy because passive health checks
- // run without access to h.
- if h.HealthChecks != nil {
- upstream.healthCheckPolicy = h.HealthChecks.Passive
- }
+ for _, u := range h.Upstreams {
+ h.provisionUpstream(u)
}
if h.HealthChecks != nil {
@@ -413,79 +402,127 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
repl.Set("http.reverse_proxy.duration", time.Since(start))
}()
+ // in the proxy loop, each iteration is an attempt to proxy the request,
+ // and because we may retry some number of times, carry over the error
+ // from previous tries because of the nuances of load balancing & retries
var proxyErr error
for {
- // choose an available upstream
- upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, clonedReq, w)
- if upstream == nil {
- if proxyErr == nil {
- proxyErr = fmt.Errorf("no upstreams available")
- }
- if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, clonedReq) {
- break
- }
- continue
+ var done bool
+ done, proxyErr = h.proxyLoopIteration(clonedReq, w, proxyErr, start, repl, reqHeader, reqHost, next)
+ if done {
+ break
}
+ }
+
+ if proxyErr != nil {
+ return statusError(proxyErr)
+ }
- // the dial address may vary per-request if placeholders are
- // used, so perform those replacements here; the resulting
- // DialInfo struct should have valid network address syntax
- dialInfo, err := upstream.fillDialInfo(clonedReq)
+ return nil
+}
+
+// proxyLoopIteration implements an iteration of the proxy loop. Despite the enormous amount of local state
+// that has to be passed in, we brought this into its own method so that we could run defer more easily.
+// It returns true when the loop is done and should break; false otherwise. The error value returned should
+// be assigned to the proxyErr value for the next iteration of the loop (or the error handled after break).
+func (h *Handler) proxyLoopIteration(r *http.Request, w http.ResponseWriter, proxyErr error, start time.Time,
+ repl *caddy.Replacer, reqHeader http.Header, reqHost string, next caddyhttp.Handler) (bool, error) {
+ // get the updated list of upstreams
+ upstreams := h.Upstreams
+ if h.DynamicUpstreams != nil {
+ dUpstreams, err := h.DynamicUpstreams.GetUpstreams(r)
if err != nil {
- return statusError(fmt.Errorf("making dial info: %v", err))
+ h.logger.Error("failed getting dynamic upstreams; falling back to static upstreams", zap.Error(err))
+ } else {
+ upstreams = dUpstreams
+ for _, dUp := range dUpstreams {
+ h.provisionUpstream(dUp)
+ }
+ h.logger.Debug("provisioned dynamic upstreams", zap.Int("count", len(dUpstreams)))
+ defer func() {
+ // these upstreams are dynamic, so they are only used for this iteration
+ // of the proxy loop; be sure to let them go away when we're done with them
+ for _, upstream := range dUpstreams {
+ _, _ = hosts.Delete(upstream.String())
+ }
+ }()
}
+ }
- // attach to the request information about how to dial the upstream;
- // this is necessary because the information cannot be sufficiently
- // or satisfactorily represented in a URL
- caddyhttp.SetVar(r.Context(), dialInfoVarKey, dialInfo)
-
- // set placeholders with information about this upstream
- repl.Set("http.reverse_proxy.upstream.address", dialInfo.String())
- repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address)
- repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host)
- repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port)
- repl.Set("http.reverse_proxy.upstream.requests", upstream.Host.NumRequests())
- repl.Set("http.reverse_proxy.upstream.max_requests", upstream.MaxRequests)
- repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails())
-
- // mutate request headers according to this upstream;
- // because we're in a retry loop, we have to copy
- // headers (and the Host value) from the original
- // so that each retry is identical to the first
- if h.Headers != nil && h.Headers.Request != nil {
- clonedReq.Header = make(http.Header)
- copyHeader(clonedReq.Header, reqHeader)
- clonedReq.Host = reqHost
- h.Headers.Request.ApplyToRequest(clonedReq)
+ // choose an available upstream
+ upstream := h.LoadBalancing.SelectionPolicy.Select(upstreams, r, w)
+ if upstream == nil {
+ if proxyErr == nil {
+ proxyErr = fmt.Errorf("no upstreams available")
}
-
- // proxy the request to that upstream
- proxyErr = h.reverseProxy(w, clonedReq, repl, dialInfo, next)
- if proxyErr == nil || proxyErr == context.Canceled {
- // context.Canceled happens when the downstream client
- // cancels the request, which is not our failure
- return nil
+ if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, r) {
+ return true, proxyErr
}
+ return false, proxyErr
+ }
- // if the roundtrip was successful, don't retry the request or
- // ding the health status of the upstream (an error can still
- // occur after the roundtrip if, for example, a response handler
- // after the roundtrip returns an error)
- if succ, ok := proxyErr.(roundtripSucceeded); ok {
- return succ.error
- }
+ // the dial address may vary per-request if placeholders are
+ // used, so perform those replacements here; the resulting
+ // DialInfo struct should have valid network address syntax
+ dialInfo, err := upstream.fillDialInfo(r)
+ if err != nil {
+ return true, fmt.Errorf("making dial info: %v", err)
+ }
- // remember this failure (if enabled)
- h.countFailure(upstream)
+ h.logger.Debug("selected upstream",
+ zap.String("dial", dialInfo.Address),
+ zap.Int("total_upstreams", len(upstreams)))
- // if we've tried long enough, break
- if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, clonedReq) {
- break
- }
+ // attach to the request information about how to dial the upstream;
+ // this is necessary because the information cannot be sufficiently
+ // or satisfactorily represented in a URL
+ caddyhttp.SetVar(r.Context(), dialInfoVarKey, dialInfo)
+
+ // set placeholders with information about this upstream
+ repl.Set("http.reverse_proxy.upstream.address", dialInfo.String())
+ repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address)
+ repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host)
+ repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port)
+ repl.Set("http.reverse_proxy.upstream.requests", upstream.Host.NumRequests())
+ repl.Set("http.reverse_proxy.upstream.max_requests", upstream.MaxRequests)
+ repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails())
+
+ // mutate request headers according to this upstream;
+ // because we're in a retry loop, we have to copy
+ // headers (and the r.Host value) from the original
+ // so that each retry is identical to the first
+ if h.Headers != nil && h.Headers.Request != nil {
+ r.Header = make(http.Header)
+ copyHeader(r.Header, reqHeader)
+ r.Host = reqHost
+ h.Headers.Request.ApplyToRequest(r)
}
- return statusError(proxyErr)
+ // proxy the request to that upstream
+ proxyErr = h.reverseProxy(w, r, repl, dialInfo, next)
+ if proxyErr == nil || proxyErr == context.Canceled {
+ // context.Canceled happens when the downstream client
+ // cancels the request, which is not our failure
+ return true, nil
+ }
+
+ // if the roundtrip was successful, don't retry the request or
+ // ding the health status of the upstream (an error can still
+ // occur after the roundtrip if, for example, a response handler
+ // after the roundtrip returns an error)
+ if succ, ok := proxyErr.(roundtripSucceeded); ok {
+ return true, succ.error
+ }
+
+ // remember this failure (if enabled)
+ h.countFailure(upstream)
+
+ // if we've tried long enough, break
+ if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, r) {
+ return true, proxyErr
+ }
+
+ return false, proxyErr
}
// prepareRequest clones req so that it can be safely modified without
@@ -651,9 +688,9 @@ func (h Handler) addForwardedHeaders(req *http.Request) error {
// (This method is mostly the beginning of what was borrowed from the net/http/httputil package in the
// Go standard library which was used as the foundation.)
func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, repl *caddy.Replacer, di DialInfo, next caddyhttp.Handler) error {
- _ = di.Upstream.Host.CountRequest(1)
+ _ = di.Upstream.Host.countRequest(1)
//nolint:errcheck
- defer di.Upstream.Host.CountRequest(-1)
+ defer di.Upstream.Host.countRequest(-1)
// point the request to this upstream
h.directRequest(req, di)
@@ -905,6 +942,35 @@ func (Handler) directRequest(req *http.Request, di DialInfo) {
req.URL.Host = reqHost
}
+func (h Handler) provisionUpstream(upstream *Upstream) {
+ // create or get the host representation for this upstream
+ upstream.fillHost()
+
+ // give it the circuit breaker, if any
+ upstream.cb = h.CB
+
+ // if the passive health checker has a non-zero UnhealthyRequestCount
+ // but the upstream has no MaxRequests set (they are the same thing,
+ // but the passive health checker is a default value for for upstreams
+ // without MaxRequests), copy the value into this upstream, since the
+ // value in the upstream (MaxRequests) is what is used during
+ // availability checks
+ if h.HealthChecks != nil && h.HealthChecks.Passive != nil {
+ h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive")
+ if h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
+ upstream.MaxRequests == 0 {
+ upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
+ }
+ }
+
+ // upstreams need independent access to the passive
+ // health check policy because passive health checks
+ // run without access to h.
+ if h.HealthChecks != nil {
+ upstream.healthCheckPolicy = h.HealthChecks.Passive
+ }
+}
+
// bufferedBody reads originalBody into a buffer, then returns a reader for the buffer.
// Always close the return value when done with it, just like if it was the original body!
func (h Handler) bufferedBody(originalBody io.ReadCloser) io.ReadCloser {
@@ -1085,6 +1151,20 @@ type Selector interface {
Select(UpstreamPool, *http.Request, http.ResponseWriter) *Upstream
}
+// UpstreamSource gets the list of upstreams that can be used when
+// proxying a request. Returned upstreams will be load balanced and
+// health-checked. This should be a very fast function -- instant
+// if possible -- and the return value must be as stable as possible.
+// In other words, the list of upstreams should ideally not change much
+// across successive calls. If the list of upstreams changes or the
+// ordering is not stable, load balancing will suffer. This function
+// may be called during each retry, multiple times per request, and as
+// such, needs to be instantaneous. The returned slice will not be
+// modified.
+type UpstreamSource interface {
+ GetUpstreams(*http.Request) ([]*Upstream, error)
+}
+
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
index c28799d..7175f77 100644
--- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
@@ -22,9 +22,9 @@ import (
func testPool() UpstreamPool {
return UpstreamPool{
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
}
}
@@ -48,20 +48,20 @@ func TestRoundRobinPolicy(t *testing.T) {
t.Error("Expected third round robin host to be first host in the pool.")
}
// mark host as down
- pool[1].SetHealthy(false)
+ pool[1].setHealthy(false)
h = rrPolicy.Select(pool, req, nil)
if h != pool[2] {
t.Error("Expected to skip down host.")
}
// mark host as up
- pool[1].SetHealthy(true)
+ pool[1].setHealthy(true)
h = rrPolicy.Select(pool, req, nil)
if h == pool[2] {
t.Error("Expected to balance evenly among healthy hosts")
}
// mark host as full
- pool[1].CountRequest(1)
+ pool[1].countRequest(1)
pool[1].MaxRequests = 1
h = rrPolicy.Select(pool, req, nil)
if h != pool[2] {
@@ -74,13 +74,13 @@ func TestLeastConnPolicy(t *testing.T) {
lcPolicy := new(LeastConnSelection)
req, _ := http.NewRequest("GET", "/", nil)
- pool[0].CountRequest(10)
- pool[1].CountRequest(10)
+ pool[0].countRequest(10)
+ pool[1].countRequest(10)
h := lcPolicy.Select(pool, req, nil)
if h != pool[2] {
t.Error("Expected least connection host to be third host.")
}
- pool[2].CountRequest(100)
+ pool[2].countRequest(100)
h = lcPolicy.Select(pool, req, nil)
if h != pool[0] && h != pool[1] {
t.Error("Expected least connection host to be first or second host.")
@@ -139,7 +139,7 @@ func TestIPHashPolicy(t *testing.T) {
// we should get a healthy host if the original host is unhealthy and a
// healthy host is available
req.RemoteAddr = "172.0.0.1"
- pool[1].SetHealthy(false)
+ pool[1].setHealthy(false)
h = ipHash.Select(pool, req, nil)
if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.")
@@ -150,10 +150,10 @@ func TestIPHashPolicy(t *testing.T) {
if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.")
}
- pool[1].SetHealthy(true)
+ pool[1].setHealthy(true)
req.RemoteAddr = "172.0.0.3"
- pool[2].SetHealthy(false)
+ pool[2].setHealthy(false)
h = ipHash.Select(pool, req, nil)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
@@ -167,8 +167,8 @@ func TestIPHashPolicy(t *testing.T) {
// We should be able to resize the host pool and still be able to predict
// where a req will be routed with the same IP's used above
pool = UpstreamPool{
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
+ {Host: new(Host)},
+ {Host: new(Host)},
}
req.RemoteAddr = "172.0.0.1:80"
h = ipHash.Select(pool, req, nil)
@@ -192,8 +192,8 @@ func TestIPHashPolicy(t *testing.T) {
}
// We should get nil when there are no healthy hosts
- pool[0].SetHealthy(false)
- pool[1].SetHealthy(false)
+ pool[0].setHealthy(false)
+ pool[1].setHealthy(false)
h = ipHash.Select(pool, req, nil)
if h != nil {
t.Error("Expected ip hash policy host to be nil.")
@@ -201,25 +201,25 @@ func TestIPHashPolicy(t *testing.T) {
// Reproduce #4135
pool = UpstreamPool{
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
- }
- pool[0].SetHealthy(false)
- pool[1].SetHealthy(false)
- pool[2].SetHealthy(false)
- pool[3].SetHealthy(false)
- pool[4].SetHealthy(false)
- pool[5].SetHealthy(false)
- pool[6].SetHealthy(false)
- pool[7].SetHealthy(false)
- pool[8].SetHealthy(true)
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ {Host: new(Host)},
+ }
+ pool[0].setHealthy(false)
+ pool[1].setHealthy(false)
+ pool[2].setHealthy(false)
+ pool[3].setHealthy(false)
+ pool[4].setHealthy(false)
+ pool[5].setHealthy(false)
+ pool[6].setHealthy(false)
+ pool[7].setHealthy(false)
+ pool[8].setHealthy(true)
// We should get a result back when there is one healthy host left.
h = ipHash.Select(pool, req, nil)
@@ -239,7 +239,7 @@ func TestFirstPolicy(t *testing.T) {
t.Error("Expected first policy host to be the first host.")
}
- pool[0].SetHealthy(false)
+ pool[0].setHealthy(false)
h = firstPolicy.Select(pool, req, nil)
if h != pool[1] {
t.Error("Expected first policy host to be the second host.")
@@ -256,7 +256,7 @@ func TestURIHashPolicy(t *testing.T) {
t.Error("Expected uri policy host to be the first host.")
}
- pool[0].SetHealthy(false)
+ pool[0].setHealthy(false)
h = uriPolicy.Select(pool, request, nil)
if h != pool[1] {
t.Error("Expected uri policy host to be the first host.")
@@ -271,8 +271,8 @@ func TestURIHashPolicy(t *testing.T) {
// We should be able to resize the host pool and still be able to predict
// where a request will be routed with the same URI's used above
pool = UpstreamPool{
- {Host: new(upstreamHost)},
- {Host: new(upstreamHost)},
+ {Host: new(Host)},
+ {Host: new(Host)},
}
request = httptest.NewRequest(http.MethodGet, "/test", nil)
@@ -281,7 +281,7 @@ func TestURIHashPolicy(t *testing.T) {
t.Error("Expected uri policy host to be the first host.")
}
- pool[0].SetHealthy(false)
+ pool[0].setHealthy(false)
h = uriPolicy.Select(pool, request, nil)
if h != pool[1] {
t.Error("Expected uri policy host to be the first host.")
@@ -293,8 +293,8 @@ func TestURIHashPolicy(t *testing.T) {
t.Error("Expected uri policy host to be the second host.")
}
- pool[0].SetHealthy(false)
- pool[1].SetHealthy(false)
+ pool[0].setHealthy(false)
+ pool[1].setHealthy(false)
h = uriPolicy.Select(pool, request, nil)
if h != nil {
t.Error("Expected uri policy policy host to be nil.")
@@ -306,12 +306,12 @@ func TestLeastRequests(t *testing.T) {
pool[0].Dial = "localhost:8080"
pool[1].Dial = "localhost:8081"
pool[2].Dial = "localhost:8082"
- pool[0].SetHealthy(true)
- pool[1].SetHealthy(true)
- pool[2].SetHealthy(true)
- pool[0].CountRequest(10)
- pool[1].CountRequest(20)
- pool[2].CountRequest(30)
+ pool[0].setHealthy(true)
+ pool[1].setHealthy(true)
+ pool[2].setHealthy(true)
+ pool[0].countRequest(10)
+ pool[1].countRequest(20)
+ pool[2].countRequest(30)
result := leastRequests(pool)
@@ -329,12 +329,12 @@ func TestRandomChoicePolicy(t *testing.T) {
pool[0].Dial = "localhost:8080"
pool[1].Dial = "localhost:8081"
pool[2].Dial = "localhost:8082"
- pool[0].SetHealthy(false)
- pool[1].SetHealthy(true)
- pool[2].SetHealthy(true)
- pool[0].CountRequest(10)
- pool[1].CountRequest(20)
- pool[2].CountRequest(30)
+ pool[0].setHealthy(false)
+ pool[1].setHealthy(true)
+ pool[2].setHealthy(true)
+ pool[0].countRequest(10)
+ pool[1].countRequest(20)
+ pool[2].countRequest(30)
request := httptest.NewRequest(http.MethodGet, "/test", nil)
randomChoicePolicy := new(RandomChoiceSelection)
@@ -357,9 +357,9 @@ func TestCookieHashPolicy(t *testing.T) {
pool[0].Dial = "localhost:8080"
pool[1].Dial = "localhost:8081"
pool[2].Dial = "localhost:8082"
- pool[0].SetHealthy(true)
- pool[1].SetHealthy(false)
- pool[2].SetHealthy(false)
+ pool[0].setHealthy(true)
+ pool[1].setHealthy(false)
+ pool[2].setHealthy(false)
request := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
cookieHashPolicy := new(CookieHashSelection)
@@ -374,8 +374,8 @@ func TestCookieHashPolicy(t *testing.T) {
if h != pool[0] {
t.Error("Expected cookieHashPolicy host to be the first only available host.")
}
- pool[1].SetHealthy(true)
- pool[2].SetHealthy(true)
+ pool[1].setHealthy(true)
+ pool[2].setHealthy(true)
request = httptest.NewRequest(http.MethodGet, "/test", nil)
w = httptest.NewRecorder()
request.AddCookie(cookieServer1)
@@ -387,7 +387,7 @@ func TestCookieHashPolicy(t *testing.T) {
if len(s) != 0 {
t.Error("Expected cookieHashPolicy to not set a new cookie.")
}
- pool[0].SetHealthy(false)
+ pool[0].setHealthy(false)
request = httptest.NewRequest(http.MethodGet, "/test", nil)
w = httptest.NewRecorder()
request.AddCookie(cookieServer1)
diff --git a/modules/caddyhttp/reverseproxy/upstreams.go b/modules/caddyhttp/reverseproxy/upstreams.go
new file mode 100644
index 0000000..eb5845f
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/upstreams.go
@@ -0,0 +1,377 @@
+package reverseproxy
+
+import (
+ "context"
+ "fmt"
+ weakrand "math/rand"
+ "net"
+ "net/http"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/caddyserver/caddy/v2"
+ "go.uber.org/zap"
+)
+
+func init() {
+ caddy.RegisterModule(SRVUpstreams{})
+ caddy.RegisterModule(AUpstreams{})
+}
+
+// SRVUpstreams provides upstreams from SRV lookups.
+// The lookup DNS name can be configured either by
+// its individual parts (that is, specifying the
+// service, protocol, and name separately) to form
+// the standard "_service._proto.name" domain, or
+// the domain can be specified directly in name by
+// leaving service and proto empty. See RFC 2782.
+//
+// Lookups are cached and refreshed at the configured
+// refresh interval.
+//
+// Returned upstreams are sorted by priority and weight.
+type SRVUpstreams struct {
+ // The service label.
+ Service string `json:"service,omitempty"`
+
+ // The protocol label; either tcp or udp.
+ Proto string `json:"proto,omitempty"`
+
+ // The name label; or, if service and proto are
+ // empty, the entire domain name to look up.
+ Name string `json:"name,omitempty"`
+
+ // The interval at which to refresh the SRV lookup.
+ // Results are cached between lookups. Default: 1m
+ Refresh caddy.Duration `json:"refresh,omitempty"`
+
+ // Configures the DNS resolver used to resolve the
+ // SRV address to SRV records.
+ Resolver *UpstreamResolver `json:"resolver,omitempty"`
+
+ // If Resolver is configured, how long to wait before
+ // timing out trying to connect to the DNS server.
+ DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
+
+ // If Resolver is configured, how long to wait before
+ // spawning an RFC 6555 Fast Fallback connection.
+ // A negative value disables this.
+ FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
+
+ resolver *net.Resolver
+
+ logger *zap.Logger
+}
+
+// CaddyModule returns the Caddy module information.
+func (SRVUpstreams) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ ID: "http.reverse_proxy.upstreams.srv",
+ New: func() caddy.Module { return new(SRVUpstreams) },
+ }
+}
+
+// String returns the RFC 2782 representation of the SRV domain.
+func (su SRVUpstreams) String() string {
+ return fmt.Sprintf("_%s._%s.%s", su.Service, su.Proto, su.Name)
+}
+
+func (su *SRVUpstreams) Provision(ctx caddy.Context) error {
+ su.logger = ctx.Logger(su)
+ if su.Refresh == 0 {
+ su.Refresh = caddy.Duration(time.Minute)
+ }
+
+ if su.Resolver != nil {
+ err := su.Resolver.ParseAddresses()
+ if err != nil {
+ return err
+ }
+ d := &net.Dialer{
+ Timeout: time.Duration(su.DialTimeout),
+ FallbackDelay: time.Duration(su.FallbackDelay),
+ }
+ su.resolver = &net.Resolver{
+ PreferGo: true,
+ Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
+ //nolint:gosec
+ addr := su.Resolver.netAddrs[weakrand.Intn(len(su.Resolver.netAddrs))]
+ return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
+ },
+ }
+ }
+ if su.resolver == nil {
+ su.resolver = net.DefaultResolver
+ }
+
+ return nil
+}
+
+func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
+ suStr := su.String()
+
+ // first, use a cheap read-lock to return a cached result quickly
+ srvsMu.RLock()
+ cached := srvs[suStr]
+ srvsMu.RUnlock()
+ if cached.isFresh() {
+ return cached.upstreams, nil
+ }
+
+ // otherwise, obtain a write-lock to update the cached value
+ srvsMu.Lock()
+ defer srvsMu.Unlock()
+
+ // check to see if it's still stale, since we're now in a different
+ // lock from when we first checked freshness; another goroutine might
+ // have refreshed it in the meantime before we re-obtained our lock
+ cached = srvs[suStr]
+ if cached.isFresh() {
+ return cached.upstreams, nil
+ }
+
+ // prepare parameters and perform the SRV lookup
+ repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
+ service := repl.ReplaceAll(su.Service, "")
+ proto := repl.ReplaceAll(su.Proto, "")
+ name := repl.ReplaceAll(su.Name, "")
+
+ su.logger.Debug("refreshing SRV upstreams",
+ zap.String("service", service),
+ zap.String("proto", proto),
+ zap.String("name", name))
+
+ _, records, err := su.resolver.LookupSRV(r.Context(), service, proto, name)
+ if err != nil {
+ // From LookupSRV docs: "If the response contains invalid names, those records are filtered
+ // out and an error will be returned alongside the the remaining results, if any." Thus, we
+ // only return an error if no records were also returned.
+ if len(records) == 0 {
+ return nil, err
+ }
+ su.logger.Warn("SRV records filtered", zap.Error(err))
+ }
+
+ upstreams := make([]*Upstream, len(records))
+ for i, rec := range records {
+ su.logger.Debug("discovered SRV record",
+ zap.String("target", rec.Target),
+ zap.Uint16("port", rec.Port),
+ zap.Uint16("priority", rec.Priority),
+ zap.Uint16("weight", rec.Weight))
+ addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port)))
+ upstreams[i] = &Upstream{Dial: addr}
+ }
+
+ // before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
+ if cached.freshness.IsZero() && len(srvs) >= 100 {
+ for randomKey := range srvs {
+ delete(srvs, randomKey)
+ break
+ }
+ }
+
+ srvs[suStr] = srvLookup{
+ srvUpstreams: su,
+ freshness: time.Now(),
+ upstreams: upstreams,
+ }
+
+ return upstreams, nil
+}
+
+type srvLookup struct {
+ srvUpstreams SRVUpstreams
+ freshness time.Time
+ upstreams []*Upstream
+}
+
+func (sl srvLookup) isFresh() bool {
+ return time.Since(sl.freshness) < time.Duration(sl.srvUpstreams.Refresh)
+}
+
+var (
+ srvs = make(map[string]srvLookup)
+ srvsMu sync.RWMutex
+)
+
+// AUpstreams provides upstreams from A/AAAA lookups.
+// Results are cached and refreshed at the configured
+// refresh interval.
+type AUpstreams struct {
+ // The domain name to look up.
+ Name string `json:"name,omitempty"`
+
+ // The port to use with the upstreams. Default: 80
+ Port string `json:"port,omitempty"`
+
+ // The interval at which to refresh the A lookup.
+ // Results are cached between lookups. Default: 1m
+ Refresh caddy.Duration `json:"refresh,omitempty"`
+
+ // Configures the DNS resolver used to resolve the
+ // domain name to A records.
+ Resolver *UpstreamResolver `json:"resolver,omitempty"`
+
+ // If Resolver is configured, how long to wait before
+ // timing out trying to connect to the DNS server.
+ DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
+
+ // If Resolver is configured, how long to wait before
+ // spawning an RFC 6555 Fast Fallback connection.
+ // A negative value disables this.
+ FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
+
+ resolver *net.Resolver
+}
+
+// CaddyModule returns the Caddy module information.
+func (AUpstreams) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ ID: "http.reverse_proxy.upstreams.a",
+ New: func() caddy.Module { return new(AUpstreams) },
+ }
+}
+
+func (au AUpstreams) String() string { return au.Name }
+
+func (au *AUpstreams) Provision(_ caddy.Context) error {
+ if au.Refresh == 0 {
+ au.Refresh = caddy.Duration(time.Minute)
+ }
+ if au.Port == "" {
+ au.Port = "80"
+ }
+
+ if au.Resolver != nil {
+ err := au.Resolver.ParseAddresses()
+ if err != nil {
+ return err
+ }
+ d := &net.Dialer{
+ Timeout: time.Duration(au.DialTimeout),
+ FallbackDelay: time.Duration(au.FallbackDelay),
+ }
+ au.resolver = &net.Resolver{
+ PreferGo: true,
+ Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
+ //nolint:gosec
+ addr := au.Resolver.netAddrs[weakrand.Intn(len(au.Resolver.netAddrs))]
+ return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
+ },
+ }
+ }
+ if au.resolver == nil {
+ au.resolver = net.DefaultResolver
+ }
+
+ return nil
+}
+
+func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
+ auStr := au.String()
+
+ // first, use a cheap read-lock to return a cached result quickly
+ aAaaaMu.RLock()
+ cached := aAaaa[auStr]
+ aAaaaMu.RUnlock()
+ if cached.isFresh() {
+ return cached.upstreams, nil
+ }
+
+ // otherwise, obtain a write-lock to update the cached value
+ aAaaaMu.Lock()
+ defer aAaaaMu.Unlock()
+
+ // check to see if it's still stale, since we're now in a different
+ // lock from when we first checked freshness; another goroutine might
+ // have refreshed it in the meantime before we re-obtained our lock
+ cached = aAaaa[auStr]
+ if cached.isFresh() {
+ return cached.upstreams, nil
+ }
+
+ repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
+ name := repl.ReplaceAll(au.Name, "")
+ port := repl.ReplaceAll(au.Port, "")
+
+ ips, err := au.resolver.LookupIPAddr(r.Context(), name)
+ if err != nil {
+ return nil, err
+ }
+
+ upstreams := make([]*Upstream, len(ips))
+ for i, ip := range ips {
+ upstreams[i] = &Upstream{
+ Dial: net.JoinHostPort(ip.String(), port),
+ }
+ }
+
+ // before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
+ if cached.freshness.IsZero() && len(srvs) >= 100 {
+ for randomKey := range aAaaa {
+ delete(aAaaa, randomKey)
+ break
+ }
+ }
+
+ aAaaa[auStr] = aLookup{
+ aUpstreams: au,
+ freshness: time.Now(),
+ upstreams: upstreams,
+ }
+
+ return upstreams, nil
+}
+
+type aLookup struct {
+ aUpstreams AUpstreams
+ freshness time.Time
+ upstreams []*Upstream
+}
+
+func (al aLookup) isFresh() bool {
+ return time.Since(al.freshness) < time.Duration(al.aUpstreams.Refresh)
+}
+
+// UpstreamResolver holds the set of addresses of DNS resolvers of
+// upstream addresses
+type UpstreamResolver struct {
+ // The addresses of DNS resolvers to use when looking up the addresses of proxy upstreams.
+ // It accepts [network addresses](/docs/conventions#network-addresses)
+ // with port range of only 1. If the host is an IP address, it will be dialed directly to resolve the upstream server.
+ // If the host is not an IP address, the addresses are resolved using the [name resolution convention](https://golang.org/pkg/net/#hdr-Name_Resolution) of the Go standard library.
+ // If the array contains more than 1 resolver address, one is chosen at random.
+ Addresses []string `json:"addresses,omitempty"`
+ netAddrs []caddy.NetworkAddress
+}
+
+// ParseAddresses parses all the configured network addresses
+// and ensures they're ready to be used.
+func (u *UpstreamResolver) ParseAddresses() error {
+ for _, v := range u.Addresses {
+ addr, err := caddy.ParseNetworkAddress(v)
+ if err != nil {
+ return err
+ }
+ if addr.PortRangeSize() != 1 {
+ return fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr)
+ }
+ u.netAddrs = append(u.netAddrs, addr)
+ }
+ return nil
+}
+
+var (
+ aAaaa = make(map[string]aLookup)
+ aAaaaMu sync.RWMutex
+)
+
+// Interface guards
+var (
+ _ caddy.Provisioner = (*SRVUpstreams)(nil)
+ _ UpstreamSource = (*SRVUpstreams)(nil)
+ _ caddy.Provisioner = (*AUpstreams)(nil)
+ _ UpstreamSource = (*AUpstreams)(nil)
+)