From 0830fbad0347ead1dbea60e664556b263e44653f Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Thu, 5 Sep 2019 13:14:39 -0600 Subject: Reconcile upstream dial addresses and request host/URL information My goodness that was complicated Blessed be request.Context Sort of --- listeners.go | 18 +-- listeners_test.go | 22 +++- modules/caddyhttp/caddyhttp.go | 12 +- modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go | 102 ++++++----------- modules/caddyhttp/reverseproxy/healthchecks.go | 68 +++++++---- modules/caddyhttp/reverseproxy/hosts.go | 38 +++++-- modules/caddyhttp/reverseproxy/httptransport.go | 28 ++--- modules/caddyhttp/reverseproxy/reverseproxy.go | 130 ++++++++++++++-------- modules/caddyhttp/server.go | 2 +- 9 files changed, 237 insertions(+), 183 deletions(-) diff --git a/listeners.go b/listeners.go index 3c3a704..846059d 100644 --- a/listeners.go +++ b/listeners.go @@ -165,19 +165,19 @@ var ( listenersMu sync.Mutex ) -// ParseListenAddr parses addr, a string of the form "network/host:port" +// ParseNetworkAddress parses addr, a string of the form "network/host:port" // (with any part optional) into its component parts. Because a port can // also be a port range, there may be multiple addresses returned. -func ParseListenAddr(addr string) (network string, addrs []string, err error) { +func ParseNetworkAddress(addr string) (network string, addrs []string, err error) { var host, port string - network, host, port, err = SplitListenAddr(addr) + network, host, port, err = SplitNetworkAddress(addr) if network == "" { network = "tcp" } if err != nil { return } - if network == "unix" { + if network == "unix" || network == "unixgram" || network == "unixpacket" { addrs = []string{host} return } @@ -204,14 +204,14 @@ func ParseListenAddr(addr string) (network string, addrs []string, err error) { return } -// SplitListenAddr splits a into its network, host, and port components. +// SplitNetworkAddress splits a into its network, host, and port components. // Note that port may be a port range, or omitted for unix sockets. -func SplitListenAddr(a string) (network, host, port string, err error) { +func SplitNetworkAddress(a string) (network, host, port string, err error) { if idx := strings.Index(a, "/"); idx >= 0 { network = strings.ToLower(strings.TrimSpace(a[:idx])) a = a[idx+1:] } - if network == "unix" { + if network == "unix" || network == "unixgram" || network == "unixpacket" { host = a return } @@ -219,11 +219,11 @@ func SplitListenAddr(a string) (network, host, port string, err error) { return } -// JoinListenAddr combines network, host, and port into a single +// JoinNetworkAddress combines network, host, and port into a single // address string of the form "network/host:port". Port may be a // port range. For unix sockets, the network should be "unix" and // the path to the socket should be given in the host argument. -func JoinListenAddr(network, host, port string) string { +func JoinNetworkAddress(network, host, port string) string { var a string if network != "" { a = network + "/" diff --git a/listeners_test.go b/listeners_test.go index 7c5a2fc..11d3980 100644 --- a/listeners_test.go +++ b/listeners_test.go @@ -19,7 +19,7 @@ import ( "testing" ) -func TestSplitListenerAddr(t *testing.T) { +func TestSplitNetworkAddress(t *testing.T) { for i, tc := range []struct { input string expectNetwork string @@ -67,8 +67,18 @@ func TestSplitListenerAddr(t *testing.T) { expectNetwork: "unix", expectHost: "/foo/bar", }, + { + input: "unixgram//foo/bar", + expectNetwork: "unixgram", + expectHost: "/foo/bar", + }, + { + input: "unixpacket//foo/bar", + expectNetwork: "unixpacket", + expectHost: "/foo/bar", + }, } { - actualNetwork, actualHost, actualPort, err := SplitListenAddr(tc.input) + actualNetwork, actualHost, actualPort, err := SplitNetworkAddress(tc.input) if tc.expectErr && err == nil { t.Errorf("Test %d: Expected error but got: %v", i, err) } @@ -87,7 +97,7 @@ func TestSplitListenerAddr(t *testing.T) { } } -func TestJoinListenerAddr(t *testing.T) { +func TestJoinNetworkAddress(t *testing.T) { for i, tc := range []struct { network, host, port string expect string @@ -129,14 +139,14 @@ func TestJoinListenerAddr(t *testing.T) { expect: "unix//foo/bar", }, } { - actual := JoinListenAddr(tc.network, tc.host, tc.port) + actual := JoinNetworkAddress(tc.network, tc.host, tc.port) if actual != tc.expect { t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual) } } } -func TestParseListenerAddr(t *testing.T) { +func TestParseNetworkAddress(t *testing.T) { for i, tc := range []struct { input string expectNetwork string @@ -194,7 +204,7 @@ func TestParseListenerAddr(t *testing.T) { expectAddrs: []string{"localhost:0"}, }, } { - actualNetwork, actualAddrs, err := ParseListenAddr(tc.input) + actualNetwork, actualAddrs, err := ParseNetworkAddress(tc.input) if tc.expectErr && err == nil { t.Errorf("Test %d: Expected error but got: %v", i, err) } diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go index 300b5fd..174e316 100644 --- a/modules/caddyhttp/caddyhttp.go +++ b/modules/caddyhttp/caddyhttp.go @@ -108,7 +108,7 @@ func (app *App) Validate() error { lnAddrs := make(map[string]string) for srvName, srv := range app.Servers { for _, addr := range srv.Listen { - netw, expanded, err := caddy.ParseListenAddr(addr) + netw, expanded, err := caddy.ParseNetworkAddress(addr) if err != nil { return fmt.Errorf("invalid listener address '%s': %v", addr, err) } @@ -149,7 +149,7 @@ func (app *App) Start() error { } for _, lnAddr := range srv.Listen { - network, addrs, err := caddy.ParseListenAddr(lnAddr) + network, addrs, err := caddy.ParseNetworkAddress(lnAddr) if err != nil { return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err) } @@ -309,7 +309,7 @@ func (app *App) automaticHTTPS() error { // create HTTP->HTTPS redirects for _, addr := range srv.Listen { - netw, host, port, err := caddy.SplitListenAddr(addr) + netw, host, port, err := caddy.SplitNetworkAddress(addr) if err != nil { return fmt.Errorf("%s: invalid listener address: %v", srvName, addr) } @@ -318,7 +318,7 @@ func (app *App) automaticHTTPS() error { if httpPort == 0 { httpPort = DefaultHTTPPort } - httpRedirLnAddr := caddy.JoinListenAddr(netw, host, strconv.Itoa(httpPort)) + httpRedirLnAddr := caddy.JoinNetworkAddress(netw, host, strconv.Itoa(httpPort)) lnAddrMap[httpRedirLnAddr] = struct{}{} if parts := strings.SplitN(port, "-", 2); len(parts) == 2 { @@ -361,7 +361,7 @@ func (app *App) automaticHTTPS() error { var lnAddrs []string mapLoop: for addr := range lnAddrMap { - netw, addrs, err := caddy.ParseListenAddr(addr) + netw, addrs, err := caddy.ParseNetworkAddress(addr) if err != nil { continue } @@ -386,7 +386,7 @@ func (app *App) automaticHTTPS() error { func (app *App) listenerTaken(network, address string) bool { for _, srv := range app.Servers { for _, addr := range srv.Listen { - netw, addrs, err := caddy.ParseListenAddr(addr) + netw, addrs, err := caddy.ParseNetworkAddress(addr) if err != nil || netw != network { continue } diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go index 32f094b..35fef5f 100644 --- a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go +++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy" "github.com/caddyserver/caddy/v2/modules/caddytls" "github.com/caddyserver/caddy/v2" @@ -34,6 +35,7 @@ func init() { caddy.RegisterModule(Transport{}) } +// Transport facilitates FastCGI communication. type Transport struct { ////////////////////////////// // TODO: taken from v1 Handler type @@ -57,32 +59,32 @@ type Transport struct { // Use this directory as the fastcgi root directory. Defaults to the root // directory of the parent virtual host. - Root string + Root string `json:"root,omitempty"` // The path in the URL will be split into two, with the first piece ending // with the value of SplitPath. The first piece will be assumed as the // actual resource (CGI script) name, and the second piece will be set to // PATH_INFO for the CGI script to use. - SplitPath string + SplitPath string `json:"split_path,omitempty"` // If the URL ends with '/' (which indicates a directory), these index // files will be tried instead. - IndexFiles []string + // IndexFiles []string // Environment Variables - EnvVars [][2]string + EnvVars [][2]string `json:"env,omitempty"` // Ignored paths - IgnoredSubPaths []string + // IgnoredSubPaths []string // The duration used to set a deadline when connecting to an upstream. - DialTimeout time.Duration + DialTimeout caddy.Duration `json:"dial_timeout,omitempty"` // The duration used to set a deadline when reading from the FastCGI server. - ReadTimeout time.Duration + ReadTimeout caddy.Duration `json:"read_timeout,omitempty"` // The duration used to set a deadline when sending to the FastCGI server. - WriteTimeout time.Duration + WriteTimeout caddy.Duration `json:"write_timeout,omitempty"` } // CaddyModule returns the Caddy module information. @@ -93,102 +95,62 @@ func (Transport) CaddyModule() caddy.ModuleInfo { } } +// RoundTrip implements http.RoundTripper. func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) { - // Create environment for CGI script env, err := t.buildEnv(r) if err != nil { return nil, fmt.Errorf("building environment: %v", err) } - // TODO: - // Connect to FastCGI gateway - // address, err := f.Address() - // if err != nil { - // return http.StatusBadGateway, err - // } - // network, address := parseAddress(address) - network, address := "tcp", r.URL.Host // TODO: - + // TODO: doesn't dialer have a Timeout field? ctx := context.Background() if t.DialTimeout > 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, t.DialTimeout) + ctx, cancel = context.WithTimeout(ctx, time.Duration(t.DialTimeout)) defer cancel() } + // extract dial information from request (this + // should embedded by the reverse proxy) + network, address := "tcp", r.URL.Host + if dialInfoVal := ctx.Value(reverseproxy.DialInfoCtxKey); dialInfoVal != nil { + dialInfo := dialInfoVal.(reverseproxy.DialInfo) + network = dialInfo.Network + address = dialInfo.Address + } + fcgiBackend, err := DialContext(ctx, network, address) if err != nil { return nil, fmt.Errorf("dialing backend: %v", err) } - // fcgiBackend is closed when response body is closed (see clientCloser) + // fcgiBackend gets closed when response body is closed (see clientCloser) // read/write timeouts - if err := fcgiBackend.SetReadTimeout(t.ReadTimeout); err != nil { + if err := fcgiBackend.SetReadTimeout(time.Duration(t.ReadTimeout)); err != nil { return nil, fmt.Errorf("setting read timeout: %v", err) } - if err := fcgiBackend.SetWriteTimeout(t.WriteTimeout); err != nil { + if err := fcgiBackend.SetWriteTimeout(time.Duration(t.WriteTimeout)); err != nil { return nil, fmt.Errorf("setting write timeout: %v", err) } - var resp *http.Response - - var contentLength int64 - // if ContentLength is already set - if r.ContentLength > 0 { - contentLength = r.ContentLength - } else { + contentLength := r.ContentLength + if contentLength == 0 { contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64) } + + var resp *http.Response switch r.Method { - case "HEAD": + case http.MethodHead: resp, err = fcgiBackend.Head(env) - case "GET": + case http.MethodGet: resp, err = fcgiBackend.Get(env, r.Body, contentLength) - case "OPTIONS": + case http.MethodOptions: resp, err = fcgiBackend.Options(env) default: resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength) } - // TODO: return resp, err - - // Stuff brought over from v1 that might not be necessary here: - - // if resp != nil && resp.Body != nil { - // defer resp.Body.Close() - // } - - // if err != nil { - // if err, ok := err.(net.Error); ok && err.Timeout() { - // return http.StatusGatewayTimeout, err - // } else if err != io.EOF { - // return http.StatusBadGateway, err - // } - // } - - // // Write response header - // writeHeader(w, resp) - - // // Write the response body - // _, err = io.Copy(w, resp.Body) - // if err != nil { - // return http.StatusBadGateway, err - // } - - // // Log any stderr output from upstream - // if fcgiBackend.stderr.Len() != 0 { - // // Remove trailing newline, error logger already does this. - // err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n")) - // } - - // // Normally we would return the status code if it is an error status (>= 400), - // // however, upstream FastCGI apps don't know about our contract and have - // // probably already written an error page. So we just return 0, indicating - // // that the response body is already written. However, we do return any - // // error value so it can be logged. - // // Note that the proxy middleware works the same way, returning status=0. - // return 0, err } // buildEnv returns a set of CGI environment variables for the request. diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go index 673f7c4..abe0f9c 100644 --- a/modules/caddyhttp/reverseproxy/healthchecks.go +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -15,6 +15,7 @@ package reverseproxy import ( + "context" "fmt" "io" "io/ioutil" @@ -93,15 +94,31 @@ func (h *Handler) activeHealthChecker() { // health checks for all hosts in the global repository. func (h *Handler) doActiveHealthChecksForAllHosts() { hosts.Range(func(key, value interface{}) bool { - addr := key.(string) + networkAddr := key.(string) host := value.(Host) - go func(addr string, host Host) { - err := h.doActiveHealthCheck(addr, host) + go func(networkAddr string, host Host) { + network, addrs, err := caddy.ParseNetworkAddress(networkAddr) if err != nil { - log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", addr, err) + log.Printf("[ERROR] reverse_proxy: active health check for host %s: bad network address: %v", networkAddr, err) + return } - }(addr, host) + if len(addrs) != 1 { + log.Printf("[ERROR] reverse_proxy: active health check for host %s: multiple addresses (upstream must map to only one address)", networkAddr) + return + } + hostAddr := addrs[0] + if network == "unix" || network == "unixgram" || network == "unixpacket" { + // this will be used as the Host portion of a http.Request URL, and + // paths to socket files would produce an error when creating URL, + // so use a fake Host value instead + hostAddr = network + } + err = h.doActiveHealthCheck(DialInfo{network, addrs[0]}, hostAddr, host) + if err != nil { + log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", networkAddr, err) + } + }(networkAddr, host) // continue to iterate all hosts return true @@ -115,26 +132,39 @@ func (h *Handler) doActiveHealthChecksForAllHosts() { // according to whether it passes the health check. An error is // returned only if the health check fails to occur or if marking // the host's health status fails. -func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error { - // create the URL for the health check - u, err := url.Parse(hostAddr) - if err != nil { - return err +func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host Host) error { + // create the URL for the request that acts as a health check + scheme := "http" + if ht, ok := h.Transport.(*http.Transport); ok && ht.TLSClientConfig != nil { + // this is kind of a hacky way to know if we should use HTTPS, but whatever + scheme = "https" } - if h.HealthChecks.Active.Path != "" { - u.Path = h.HealthChecks.Active.Path + u := &url.URL{ + Scheme: scheme, + Host: hostAddr, + Path: h.HealthChecks.Active.Path, } + + // adjust the port, if configured to be different if h.HealthChecks.Active.Port != 0 { portStr := strconv.Itoa(h.HealthChecks.Active.Port) - u.Host = net.JoinHostPort(u.Hostname(), portStr) + host, _, err := net.SplitHostPort(hostAddr) + if err != nil { + host = hostAddr + } + u.Host = net.JoinHostPort(host, portStr) } - req, err := http.NewRequest(http.MethodGet, u.String(), nil) + // attach dialing information to this request + ctx := context.Background() + ctx = context.WithValue(ctx, caddy.ReplacerCtxKey, caddy.NewReplacer()) + ctx = context.WithValue(ctx, DialInfoCtxKey, dialInfo) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { - return err + return fmt.Errorf("making request: %v", err) } - // do the request, careful to tame the response body + // do the request, being careful to tame the response body resp, err := h.HealthChecks.Active.httpClient.Do(req) if err != nil { log.Printf("[INFO] reverse_proxy: active health check: %s is down (HTTP request failed: %v)", hostAddr, err) @@ -149,7 +179,7 @@ func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error { body = io.LimitReader(body, h.HealthChecks.Active.MaxSize) } defer func() { - // drain any remaining body so connection can be re-used + // drain any remaining body so connection could be re-used io.Copy(ioutil.Discard, body) resp.Body.Close() }() @@ -225,7 +255,7 @@ func (h *Handler) countFailure(upstream *Upstream) { err := upstream.Host.CountFail(1) if err != nil { log.Printf("[ERROR] proxy: upstream %s: counting failure: %v", - upstream.hostURL, err) + upstream.dialInfo, err) } // forget it later @@ -234,7 +264,7 @@ func (h *Handler) countFailure(upstream *Upstream) { err := host.CountFail(-1) if err != nil { log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v", - upstream.hostURL, err) + upstream.dialInfo, err) } }(upstream.Host, failDuration) } diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index b40e614..ad27625 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -16,7 +16,6 @@ package reverseproxy import ( "fmt" - "net/url" "sync/atomic" "github.com/caddyserver/caddy/v2" @@ -59,7 +58,7 @@ type UpstreamPool []*Upstream type Upstream struct { Host `json:"-"` - Address string `json:"address,omitempty"` + Dial string `json:"dial,omitempty"` MaxRequests int `json:"max_requests,omitempty"` // TODO: This could be really useful, to bind requests @@ -68,8 +67,8 @@ type Upstream struct { // IPAffinity string healthCheckPolicy *PassiveHealthChecks - hostURL *url.URL cb CircuitBreaker + dialInfo DialInfo } // Available returns true if the remote host @@ -101,11 +100,6 @@ func (u *Upstream) Full() bool { return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests } -// URL returns the upstream host's endpoint URL. -func (u *Upstream) URL() *url.URL { - return u.hostURL -} - // upstreamHost is the basic, in-memory representation // of the state of a remote host. It implements the // Host interface. @@ -162,6 +156,34 @@ func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) { return swapped, nil } +// DialInfo contains information needed to dial a +// connection to an upstream host. This information +// may be different than that which is represented +// in a URL (for example, unix sockets don't have +// a host that can be represented in a URL, but +// they certainly have a network name and address). +type DialInfo struct { + // The network to use. This should be one of the + // values that is accepted by net.Dial: + // https://golang.org/pkg/net/#Dial + Network string + + // The address to dial. Follows the same + // semantics and rules as net.Dial. + Address string +} + +// String returns the Caddy network address form +// by joining the network and address with a +// forward slash. +func (di DialInfo) String() string { + return di.Network + "/" + di.Address +} + +// DialInfoCtxKey is used to store a DialInfo +// in a context.Context. +const DialInfoCtxKey = caddy.CtxKey("dial_info") + // hosts is the global repository for hosts that are // currently in use by active configuration(s). This // allows the state of remote hosts to be preserved diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index d9dc457..6c1d9c8 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -15,6 +15,7 @@ package reverseproxy import ( + "context" "crypto/tls" "crypto/x509" "encoding/base64" @@ -63,14 +64,23 @@ func (HTTPTransport) CaddyModule() caddy.ModuleInfo { // Provision sets up h.RoundTripper with a http.Transport // that is ready to use. -func (h *HTTPTransport) Provision(ctx caddy.Context) error { +func (h *HTTPTransport) Provision(_ caddy.Context) error { dialer := &net.Dialer{ Timeout: time.Duration(h.DialTimeout), FallbackDelay: time.Duration(h.FallbackDelay), // TODO: Resolver } + rt := &http.Transport{ - DialContext: dialer.DialContext, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + // the proper dialing information should be embedded into the request's context + if dialInfoVal := ctx.Value(DialInfoCtxKey); dialInfoVal != nil { + dialInfo := dialInfoVal.(DialInfo) + network = dialInfo.Network + address = dialInfo.Address + } + return dialer.DialContext(ctx, network, address) + }, MaxConnsPerHost: h.MaxConnsPerHost, ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout), ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout), @@ -91,7 +101,6 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error { if h.KeepAlive != nil { dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval) - if enabled := h.KeepAlive.Enabled; enabled != nil { rt.DisableKeepAlives = !*enabled } @@ -191,16 +200,3 @@ type KeepAlive struct { MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"` IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle } - -var ( - defaultDialer = net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, - } - - defaultTransport = &http.Transport{ - DialContext: defaultDialer.DialContext, - TLSHandshakeTimeout: 5 * time.Second, - IdleConnTimeout: 2 * time.Minute, - } -) diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 7bf9a2f..5a37613 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -20,7 +20,6 @@ import ( "fmt" "net" "net/http" - "net/url" "regexp" "strings" "time" @@ -86,7 +85,18 @@ func (h *Handler) Provision(ctx caddy.Context) error { } if h.Transport == nil { - h.Transport = defaultTransport + t := &HTTPTransport{ + KeepAlive: &KeepAlive{ + ProbeInterval: caddy.Duration(30 * time.Second), + IdleConnTimeout: caddy.Duration(2 * time.Minute), + }, + DialTimeout: caddy.Duration(10 * time.Second), + } + err := t.Provision(ctx) + if err != nil { + return fmt.Errorf("provisioning default transport: %v", err) + } + h.Transport = t } if h.LoadBalancing == nil { @@ -133,51 +143,65 @@ func (h *Handler) Provision(ctx caddy.Context) error { go h.activeHealthChecker() } + var allUpstreams []*Upstream for _, upstream := range h.Upstreams { - upstream.cb = h.CB - - // url parser requires a scheme - if !strings.Contains(upstream.Address, "://") { - upstream.Address = "http://" + upstream.Address - } - u, err := url.Parse(upstream.Address) + // upstreams are allowed to map to only a single host, + // but an upstream's address may semantically represent + // multiple addresses, so make sure to handle each + // one in turn based on this one upstream config + network, addresses, err := caddy.ParseNetworkAddress(upstream.Dial) if err != nil { - return fmt.Errorf("invalid upstream address %s: %v", upstream.Address, err) - } - upstream.hostURL = u - - // if host already exists from a current config, - // use that instead; otherwise, add it - // TODO: make hosts modular, so that their state can be distributed in enterprise for example - // TODO: If distributed, the pool should be stored in storage... - var host Host = new(upstreamHost) - activeHost, loaded := hosts.LoadOrStore(u.String(), host) - if loaded { - host = activeHost.(Host) - } - upstream.Host = host - - // if the passive health checker has a non-zero "unhealthy - // request count" but the upstream has no MaxRequests set - // (they are the same thing, but one is a default value for - // for upstreams with a zero MaxRequests), copy the default - // value into this upstream, since the value in the upstream - // is what is used during availability checks - if h.HealthChecks != nil && - h.HealthChecks.Passive != nil && - h.HealthChecks.Passive.UnhealthyRequestCount > 0 && - upstream.MaxRequests == 0 { - upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount + return fmt.Errorf("parsing dial address: %v", err) } - if h.HealthChecks != nil { + for _, addr := range addresses { + // make a new upstream based on the original + // that has a singular dial address + upstreamCopy := *upstream + upstreamCopy.dialInfo = DialInfo{network, addr} + upstreamCopy.Dial = upstreamCopy.dialInfo.String() + upstreamCopy.cb = h.CB + + // if host already exists from a current config, + // use that instead; otherwise, add it + // TODO: make hosts modular, so that their state can be distributed in enterprise for example + // TODO: If distributed, the pool should be stored in storage... + var host Host = new(upstreamHost) + activeHost, loaded := hosts.LoadOrStore(upstreamCopy.Dial, host) + if loaded { + host = activeHost.(Host) + } + upstreamCopy.Host = host + + // if the passive health checker has a non-zero "unhealthy + // request count" but the upstream has no MaxRequests set + // (they are the same thing, but one is a default value for + // for upstreams with a zero MaxRequests), copy the default + // value into this upstream, since the value in the upstream + // (MaxRequests) is what is used during availability checks + if h.HealthChecks != nil && + h.HealthChecks.Passive != nil && + h.HealthChecks.Passive.UnhealthyRequestCount > 0 && + upstreamCopy.MaxRequests == 0 { + upstreamCopy.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount + } + // upstreams need independent access to the passive - // health check policy so they can, you know, passively - // do health checks - upstream.healthCheckPolicy = h.HealthChecks.Passive + // health check policy because they run outside of the + // scope of a request handler + if h.HealthChecks != nil { + upstreamCopy.healthCheckPolicy = h.HealthChecks.Passive + } + + allUpstreams = append(allUpstreams, &upstreamCopy) } } + // replace the unmarshaled upstreams (possible 1:many + // address mapping) with our list, which is mapped 1:1, + // thus may have expanded the original list + h.Upstreams = allUpstreams + return nil } @@ -192,7 +216,7 @@ func (h *Handler) Cleanup() error { // remove hosts from our config from the pool for _, upstream := range h.Upstreams { - hosts.Delete(upstream.hostURL.String()) + hosts.Delete(upstream.dialInfo.String()) } return nil @@ -222,6 +246,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht continue } + // attach to the request information about how to dial the upstream; + // this is necessary because the information cannot be sufficiently + // or satisfactorily represented in a URL + ctx := context.WithValue(r.Context(), DialInfoCtxKey, upstream.dialInfo) + r = r.WithContext(ctx) + // proxy the request to that upstream proxyErr = h.reverseProxy(w, r, upstream) if proxyErr == nil || proxyErr == context.Canceled { @@ -249,6 +279,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht // This assumes that no mutations of the request are performed // by h during or after proxying. func (h Handler) prepareRequest(req *http.Request) error { + // as a special (but very common) case, if the transport + // is HTTP, then ensure the request has the proper scheme + // because incoming requests by default are lacking it + if req.URL.Scheme == "" { + req.URL.Scheme = "http" + if ht, ok := h.Transport.(*HTTPTransport); ok && ht.TLS != nil { + req.URL.Scheme = "https" + } + } + if req.ContentLength == 0 { req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries } @@ -433,14 +473,8 @@ func (h Handler) tryAgain(start time.Time, proxyErr error) bool { // directRequest modifies only req.URL so that it points to the // given upstream host. It must modify ONLY the request URL. func (h Handler) directRequest(req *http.Request, upstream *Upstream) { - target := upstream.hostURL - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) // TODO: This might be a bug (if any part of the path was augmented from a previously-tried upstream; need to start from clean original path of request, same for query string!) - if target.RawQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = target.RawQuery + req.URL.RawQuery - } else { - req.URL.RawQuery = target.RawQuery + "&" + req.URL.RawQuery + if req.URL.Host == "" { + req.URL.Host = upstream.dialInfo.Address } } diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index f820f71..248e5f2 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -168,7 +168,7 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next // listeners in s that use a port which is not otherPort. func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool { for _, lnAddr := range s.Listen { - _, addrs, err := caddy.ParseListenAddr(lnAddr) + _, addrs, err := caddy.ParseNetworkAddress(lnAddr) if err == nil { for _, a := range addrs { _, port, err := net.SplitHostPort(a) -- cgit v1.2.3