diff options
Diffstat (limited to 'modules/caddyhttp/reverseproxy/reverseproxy.go')
-rw-r--r--[-rwxr-xr-x] | modules/caddyhttp/reverseproxy/reverseproxy.go | 784 |
1 files changed, 401 insertions, 383 deletions
diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 68393de..5a37613 100755..100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -16,227 +16,304 @@ package reverseproxy import ( "context" + "encoding/json" "fmt" - "io" - "log" "net" "net/http" - "net/url" + "regexp" "strings" - "sync" "time" + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" "golang.org/x/net/http/httpguts" ) -// ReverseProxy is an HTTP Handler that takes an incoming request and -// sends it to another server, proxying the response back to the -// client. -type ReverseProxy struct { - // Director must be a function which modifies - // the request into a new request to be sent - // using Transport. Its response is then copied - // back to the original client unmodified. - // Director must not access the provided Request - // after returning. - Director func(*http.Request) - - // The transport used to perform proxy requests. - // If nil, http.DefaultTransport is used. - Transport http.RoundTripper - - // FlushInterval specifies the flush interval - // to flush to the client while copying the - // response body. - // If zero, no periodic flushing is done. - // A negative value means to flush immediately - // after each write to the client. - // The FlushInterval is ignored when ReverseProxy - // recognizes a response as a streaming response; - // for such responses, writes are flushed to the client - // immediately. - FlushInterval time.Duration - - // ErrorLog specifies an optional logger for errors - // that occur when attempting to proxy the request. - // If nil, logging goes to os.Stderr via the log package's - // standard logger. - ErrorLog *log.Logger - - // BufferPool optionally specifies a buffer pool to - // get byte slices for use by io.CopyBuffer when - // copying HTTP response bodies. - BufferPool BufferPool - - // ModifyResponse is an optional function that modifies the - // Response from the backend. It is called if the backend - // returns a response at all, with any HTTP status code. - // If the backend is unreachable, the optional ErrorHandler is - // called without any call to ModifyResponse. - // - // If ModifyResponse returns an error, ErrorHandler is called - // with its error value. If ErrorHandler is nil, its default - // implementation is used. - ModifyResponse func(*http.Response) error - - // ErrorHandler is an optional function that handles errors - // reaching the backend or errors from ModifyResponse. - // - // If nil, the default is to log the provided error and return - // a 502 Status Bad Gateway response. - ErrorHandler func(http.ResponseWriter, *http.Request, error) +func init() { + caddy.RegisterModule(Handler{}) } -// A BufferPool is an interface for getting and returning temporary -// byte slices for use by io.CopyBuffer. -type BufferPool interface { - Get() []byte - Put([]byte) +// Handler implements a highly configurable and production-ready reverse proxy. +type Handler struct { + TransportRaw json.RawMessage `json:"transport,omitempty"` + CBRaw json.RawMessage `json:"circuit_breaker,omitempty"` + LoadBalancing *LoadBalancing `json:"load_balancing,omitempty"` + HealthChecks *HealthChecks `json:"health_checks,omitempty"` + Upstreams UpstreamPool `json:"upstreams,omitempty"` + FlushInterval caddy.Duration `json:"flush_interval,omitempty"` + + Transport http.RoundTripper `json:"-"` + CB CircuitBreaker `json:"-"` } -func singleJoiningSlash(a, b string) string { - aslash := strings.HasSuffix(a, "/") - bslash := strings.HasPrefix(b, "/") - switch { - case aslash && bslash: - return a + b[1:] - case !aslash && !bslash: - return a + "/" + b +// CaddyModule returns the Caddy module information. +func (Handler) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy", + New: func() caddy.Module { return new(Handler) }, } - return a + b } -// NewSingleHostReverseProxy returns a new ReverseProxy that routes -// URLs to the scheme, host, and base path provided in target. If the -// target's path is "/base" and the incoming request was for "/dir", -// the target request will be for /base/dir. -// NewSingleHostReverseProxy does not rewrite the Host header. -// To rewrite Host headers, use ReverseProxy directly with a custom -// Director policy. -func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { - targetQuery := target.RawQuery - director := func(req *http.Request) { - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) - if targetQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = targetQuery + req.URL.RawQuery - } else { - req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery +// Provision ensures that h is set up properly before use. +func (h *Handler) Provision(ctx caddy.Context) error { + // start by loading modules + if h.TransportRaw != nil { + val, err := ctx.LoadModuleInline("protocol", "http.handlers.reverse_proxy.transport", h.TransportRaw) + if err != nil { + return fmt.Errorf("loading transport module: %s", err) } - if _, ok := req.Header["User-Agent"]; !ok { - // explicitly disable User-Agent so it's not set to default value - req.Header.Set("User-Agent", "") + h.Transport = val.(http.RoundTripper) + h.TransportRaw = nil // allow GC to deallocate - TODO: Does this help? + } + if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil { + val, err := ctx.LoadModuleInline("policy", + "http.handlers.reverse_proxy.selection_policies", + h.LoadBalancing.SelectionPolicyRaw) + if err != nil { + return fmt.Errorf("loading load balancing selection module: %s", err) } + h.LoadBalancing.SelectionPolicy = val.(Selector) + h.LoadBalancing.SelectionPolicyRaw = nil // allow GC to deallocate - TODO: Does this help? + } + if h.CBRaw != nil { + val, err := ctx.LoadModuleInline("type", "http.handlers.reverse_proxy.circuit_breakers", h.CBRaw) + if err != nil { + return fmt.Errorf("loading circuit breaker module: %s", err) + } + h.CB = val.(CircuitBreaker) + h.CBRaw = nil // allow GC to deallocate - TODO: Does this help? } - return &ReverseProxy{Director: director} -} -func copyHeader(dst, src http.Header) { - for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) + if h.Transport == nil { + t := &HTTPTransport{ + KeepAlive: &KeepAlive{ + ProbeInterval: caddy.Duration(30 * time.Second), + IdleConnTimeout: caddy.Duration(2 * time.Minute), + }, + DialTimeout: caddy.Duration(10 * time.Second), } + err := t.Provision(ctx) + if err != nil { + return fmt.Errorf("provisioning default transport: %v", err) + } + h.Transport = t } -} -func cloneHeader(h http.Header) http.Header { - h2 := make(http.Header, len(h)) - for k, vv := range h { - vv2 := make([]string, len(vv)) - copy(vv2, vv) - h2[k] = vv2 + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + if h.LoadBalancing.SelectionPolicy == nil { + h.LoadBalancing.SelectionPolicy = RandomSelection{} + } + if h.LoadBalancing.TryDuration > 0 && h.LoadBalancing.TryInterval == 0 { + // a non-zero try_duration with a zero try_interval + // will always spin the CPU for try_duration if the + // upstream is local or low-latency; avoid that by + // defaulting to a sane wait period between attempts + h.LoadBalancing.TryInterval = caddy.Duration(250 * time.Millisecond) } - return h2 -} -// Hop-by-hop headers. These are removed when sent to the backend. -// As of RFC 7230, hop-by-hop headers are required to appear in the -// Connection header field. These are the headers defined by the -// obsoleted RFC 2616 (section 13.5.1) and are used for backward -// compatibility. -var hopHeaders = []string{ - "Connection", - "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "Te", // canonicalized version of "TE" - "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 - "Transfer-Encoding", - "Upgrade", -} + // if active health checks are enabled, configure them and start a worker + if h.HealthChecks != nil && + h.HealthChecks.Active != nil && + (h.HealthChecks.Active.Path != "" || h.HealthChecks.Active.Port != 0) { + timeout := time.Duration(h.HealthChecks.Active.Timeout) + if timeout == 0 { + timeout = 10 * time.Second + } -func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { - p.logf("http: proxy error: %v", err) - rw.WriteHeader(http.StatusBadGateway) -} + h.HealthChecks.Active.stopChan = make(chan struct{}) + h.HealthChecks.Active.httpClient = &http.Client{ + Timeout: timeout, + Transport: h.Transport, + } + + if h.HealthChecks.Active.Interval == 0 { + h.HealthChecks.Active.Interval = caddy.Duration(30 * time.Second) + } + + if h.HealthChecks.Active.ExpectBody != "" { + var err error + h.HealthChecks.Active.bodyRegexp, err = regexp.Compile(h.HealthChecks.Active.ExpectBody) + if err != nil { + return fmt.Errorf("expect_body: compiling regular expression: %v", err) + } + } + + go h.activeHealthChecker() + } -func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { - if p.ErrorHandler != nil { - return p.ErrorHandler + var allUpstreams []*Upstream + for _, upstream := range h.Upstreams { + // upstreams are allowed to map to only a single host, + // but an upstream's address may semantically represent + // multiple addresses, so make sure to handle each + // one in turn based on this one upstream config + network, addresses, err := caddy.ParseNetworkAddress(upstream.Dial) + if err != nil { + return fmt.Errorf("parsing dial address: %v", err) + } + + for _, addr := range addresses { + // make a new upstream based on the original + // that has a singular dial address + upstreamCopy := *upstream + upstreamCopy.dialInfo = DialInfo{network, addr} + upstreamCopy.Dial = upstreamCopy.dialInfo.String() + upstreamCopy.cb = h.CB + + // if host already exists from a current config, + // use that instead; otherwise, add it + // TODO: make hosts modular, so that their state can be distributed in enterprise for example + // TODO: If distributed, the pool should be stored in storage... + var host Host = new(upstreamHost) + activeHost, loaded := hosts.LoadOrStore(upstreamCopy.Dial, host) + if loaded { + host = activeHost.(Host) + } + upstreamCopy.Host = host + + // if the passive health checker has a non-zero "unhealthy + // request count" but the upstream has no MaxRequests set + // (they are the same thing, but one is a default value for + // for upstreams with a zero MaxRequests), copy the default + // value into this upstream, since the value in the upstream + // (MaxRequests) is what is used during availability checks + if h.HealthChecks != nil && + h.HealthChecks.Passive != nil && + h.HealthChecks.Passive.UnhealthyRequestCount > 0 && + upstreamCopy.MaxRequests == 0 { + upstreamCopy.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount + } + + // upstreams need independent access to the passive + // health check policy because they run outside of the + // scope of a request handler + if h.HealthChecks != nil { + upstreamCopy.healthCheckPolicy = h.HealthChecks.Passive + } + + allUpstreams = append(allUpstreams, &upstreamCopy) + } } - return p.defaultErrorHandler + + // replace the unmarshaled upstreams (possible 1:many + // address mapping) with our list, which is mapped 1:1, + // thus may have expanded the original list + h.Upstreams = allUpstreams + + return nil } -// modifyResponse conditionally runs the optional ModifyResponse hook -// and reports whether the request should proceed. -func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { - if p.ModifyResponse == nil { - return true +// Cleanup cleans up the resources made by h during provisioning. +func (h *Handler) Cleanup() error { + // stop the active health checker + if h.HealthChecks != nil && + h.HealthChecks.Active != nil && + h.HealthChecks.Active.stopChan != nil { + close(h.HealthChecks.Active.stopChan) } - if err := p.ModifyResponse(res); err != nil { - res.Body.Close() - p.getErrorHandler()(rw, req, err) - return false + + // remove hosts from our config from the pool + for _, upstream := range h.Upstreams { + hosts.Delete(upstream.dialInfo.String()) } - return true + + return nil } -func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*http.Response, error) { - transport := p.Transport - if transport == nil { - transport = http.DefaultTransport - } - - ctx := req.Context() - if cn, ok := rw.(http.CloseNotifier); ok { - var cancel context.CancelFunc - ctx, cancel = context.WithCancel(ctx) - defer cancel() - notifyChan := cn.CloseNotify() - go func() { - select { - case <-notifyChan: - cancel() - case <-ctx.Done(): +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { + // prepare the request for proxying; this is needed only once + err := h.prepareRequest(r) + if err != nil { + return caddyhttp.Error(http.StatusInternalServerError, + fmt.Errorf("preparing request for upstream round-trip: %v", err)) + } + + start := time.Now() + + var proxyErr error + for { + // choose an available upstream + upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, r) + if upstream == nil { + if proxyErr == nil { + proxyErr = fmt.Errorf("no available upstreams") + } + if !h.tryAgain(start, proxyErr) { + break } - }() + continue + } + + // attach to the request information about how to dial the upstream; + // this is necessary because the information cannot be sufficiently + // or satisfactorily represented in a URL + ctx := context.WithValue(r.Context(), DialInfoCtxKey, upstream.dialInfo) + r = r.WithContext(ctx) + + // proxy the request to that upstream + proxyErr = h.reverseProxy(w, r, upstream) + if proxyErr == nil || proxyErr == context.Canceled { + // context.Canceled happens when the downstream client + // cancels the request; we don't have to worry about that + return nil + } + + // remember this failure (if enabled) + h.countFailure(upstream) + + // if we've tried long enough, break + if !h.tryAgain(start, proxyErr) { + break + } + } + + return caddyhttp.Error(http.StatusBadGateway, proxyErr) +} + +// prepareRequest modifies req so that it is ready to be proxied, +// except for directing to a specific upstream. This method mutates +// headers and other necessary properties of the request and should +// be done just once (before proxying) regardless of proxy retries. +// This assumes that no mutations of the request are performed +// by h during or after proxying. +func (h Handler) prepareRequest(req *http.Request) error { + // as a special (but very common) case, if the transport + // is HTTP, then ensure the request has the proper scheme + // because incoming requests by default are lacking it + if req.URL.Scheme == "" { + req.URL.Scheme = "http" + if ht, ok := h.Transport.(*HTTPTransport); ok && ht.TLS != nil { + req.URL.Scheme = "https" + } } - outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay if req.ContentLength == 0 { - outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries } - outreq.Header = cloneHeader(req.Header) + req.Close = false - p.Director(outreq) - outreq.Close = false + // if User-Agent is not set by client, then explicitly + // disable it so it's not set to default value by std lib + if _, ok := req.Header["User-Agent"]; !ok { + req.Header.Set("User-Agent", "") + } - reqUpType := upgradeType(outreq.Header) - removeConnectionHeaders(outreq.Header) + reqUpType := upgradeType(req.Header) + removeConnectionHeaders(req.Header) // Remove hop-by-hop headers to the backend. Especially // important is "Connection" because we want a persistent // connection, regardless of what the client sent to us. for _, h := range hopHeaders { - hv := outreq.Header.Get(h) + hv := req.Header.Get(h) if hv == "" { continue } if h == "Te" && hv == "trailers" { - // Issue 21096: tell backend applications that + // Issue golang/go#21096: tell backend applications that // care about trailer support that we support // trailers. (We do, but we don't go out of // our way to advertise that unless the @@ -244,40 +321,72 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht // worth mentioning) continue } - outreq.Header.Del(h) + req.Header.Del(h) } // After stripping all the hop-by-hop connection headers above, add back any // necessary for protocol upgrades, such as for websockets. if reqUpType != "" { - outreq.Header.Set("Connection", "Upgrade") - outreq.Header.Set("Upgrade", reqUpType) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", reqUpType) } if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { // If we aren't the first proxy retain prior // X-Forwarded-For information as a comma+space // separated list and fold multiple headers into one. - if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + if prior, ok := req.Header["X-Forwarded-For"]; ok { clientIP = strings.Join(prior, ", ") + ", " + clientIP } - outreq.Header.Set("X-Forwarded-For", clientIP) + req.Header.Set("X-Forwarded-For", clientIP) } - res, err := transport.RoundTrip(outreq) + return nil +} + +// reverseProxy performs a round-trip to the given backend and processes the response with the client. +// (This method is mostly the beginning of what was borrowed from the net/http/httputil package in the +// Go standard library which was used as the foundation.) +func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, upstream *Upstream) error { + upstream.Host.CountRequest(1) + defer upstream.Host.CountRequest(-1) + + // point the request to this upstream + h.directRequest(req, upstream) + + // do the round-trip + start := time.Now() + res, err := h.Transport.RoundTrip(req) + latency := time.Since(start) if err != nil { - p.getErrorHandler()(rw, outreq, err) - return nil, err + return err } - // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) - if res.StatusCode == http.StatusSwitchingProtocols { - if !p.modifyResponse(rw, res, outreq) { - return res, nil + // update circuit breaker on current conditions + if upstream.cb != nil { + upstream.cb.RecordMetric(res.StatusCode, latency) + } + + // perform passive health checks (if enabled) + if h.HealthChecks != nil && h.HealthChecks.Passive != nil { + // strike if the status code matches one that is "bad" + for _, badStatus := range h.HealthChecks.Passive.UnhealthyStatus { + if caddyhttp.StatusCodeMatches(res.StatusCode, badStatus) { + h.countFailure(upstream) + } } - p.handleUpgradeResponse(rw, outreq, res) - return res, nil + // strike if the roundtrip took too long + if h.HealthChecks.Passive.UnhealthyLatency > 0 && + latency >= time.Duration(h.HealthChecks.Passive.UnhealthyLatency) { + h.countFailure(upstream) + } + } + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + h.handleUpgradeResponse(rw, req, res) + return nil } removeConnectionHeaders(res.Header) @@ -286,10 +395,6 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht res.Header.Del(h) } - if !p.modifyResponse(rw, res, outreq) { - return res, nil - } - copyHeader(rw.Header(), res.Header) // The "Trailer" header isn't included in the Transport's response, @@ -305,15 +410,16 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht rw.WriteHeader(res.StatusCode) - err = p.copyResponse(rw, res.Body, p.flushInterval(req, res)) + err = h.copyResponse(rw, res.Body, h.flushInterval(req, res)) if err != nil { defer res.Body.Close() // Since we're streaming the response, if we run into an error all we can do - // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler + // is abort the request. Issue golang/go#23643: ReverseProxy should use ErrAbortHandler // on read error while copying body. + // TODO: Look into whether we want to panic at all in our case... if !shouldPanicOnCopyError(req) { - p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) - return nil, err + // p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) + return err } panic(http.ErrAbortHandler) @@ -331,7 +437,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht if len(res.Trailer) == announcedTrailers { copyHeader(rw.Header(), res.Trailer) - return res, nil + return nil } for k, vv := range res.Trailer { @@ -341,21 +447,48 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht } } - return res, nil + return nil +} + +// tryAgain takes the time that the handler was initially invoked +// as well as any error currently obtained and returns true if +// another attempt should be made at proxying the request. If +// true is returned, it has already blocked long enough before +// the next retry (i.e. no more sleeping is needed). If false is +// returned, the handler should stop trying to proxy the request. +func (h Handler) tryAgain(start time.Time, proxyErr error) bool { + // if downstream has canceled the request, break + if proxyErr == context.Canceled { + return false + } + // if we've tried long enough, break + if time.Since(start) >= time.Duration(h.LoadBalancing.TryDuration) { + return false + } + // otherwise, wait and try the next available host + time.Sleep(time.Duration(h.LoadBalancing.TryInterval)) + return true } -var inOurTests bool // whether we're in our own tests +// directRequest modifies only req.URL so that it points to the +// given upstream host. It must modify ONLY the request URL. +func (h Handler) directRequest(req *http.Request, upstream *Upstream) { + if req.URL.Host == "" { + req.URL.Host = upstream.dialInfo.Address + } +} // shouldPanicOnCopyError reports whether the reverse proxy should // panic with http.ErrAbortHandler. This is the right thing to do by // default, but Go 1.10 and earlier did not, so existing unit tests // weren't expecting panics. Only panic in our own tests, or when // running under the HTTP server. +// TODO: I don't know if we want this at all... func shouldPanicOnCopyError(req *http.Request) bool { - if inOurTests { - // Our tests know to handle this panic. - return true - } + // if inOurTests { + // // Our tests know to handle this panic. + // return true + // } if req.Context().Value(http.ServerContextKey) != nil { // We seem to be running under an HTTP server, so // it'll recover the panic. @@ -366,146 +499,22 @@ func shouldPanicOnCopyError(req *http.Request) bool { return false } -// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. -// See RFC 7230, section 6.1 -func removeConnectionHeaders(h http.Header) { - if c := h.Get("Connection"); c != "" { - for _, f := range strings.Split(c, ",") { - if f = strings.TrimSpace(f); f != "" { - h.Del(f) - } - } - } -} - -// flushInterval returns the p.FlushInterval value, conditionally -// overriding its value for a specific request/response. -func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration { - resCT := res.Header.Get("Content-Type") - - // For Server-Sent Events responses, flush immediately. - // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream - if resCT == "text/event-stream" { - return -1 // negative means immediately - } - - // TODO: more specific cases? e.g. res.ContentLength == -1? - return p.FlushInterval -} - -func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { - if flushInterval != 0 { - if wf, ok := dst.(writeFlusher); ok { - mlw := &maxLatencyWriter{ - dst: wf, - latency: flushInterval, - } - defer mlw.stop() - dst = mlw - } - } - - var buf []byte - if p.BufferPool != nil { - buf = p.BufferPool.Get() - defer p.BufferPool.Put(buf) - } - _, err := p.copyBuffer(dst, src, buf) - return err -} - -// copyBuffer returns any write errors or non-EOF read errors, and the amount -// of bytes written. -func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { - if len(buf) == 0 { - buf = make([]byte, 32*1024) - } - var written int64 - for { - nr, rerr := src.Read(buf) - if rerr != nil && rerr != io.EOF && rerr != context.Canceled { - p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) - } - if nr > 0 { - nw, werr := dst.Write(buf[:nr]) - if nw > 0 { - written += int64(nw) - } - if werr != nil { - return written, werr - } - if nr != nw { - return written, io.ErrShortWrite - } - } - if rerr != nil { - if rerr == io.EOF { - rerr = nil - } - return written, rerr +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) } } } -func (p *ReverseProxy) logf(format string, args ...interface{}) { - if p.ErrorLog != nil { - p.ErrorLog.Printf(format, args...) - } else { - log.Printf(format, args...) - } -} - -type writeFlusher interface { - io.Writer - http.Flusher -} - -type maxLatencyWriter struct { - dst writeFlusher - latency time.Duration // non-zero; negative means to flush immediately - - mu sync.Mutex // protects t, flushPending, and dst.Flush - t *time.Timer - flushPending bool -} - -func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { - m.mu.Lock() - defer m.mu.Unlock() - n, err = m.dst.Write(p) - if m.latency < 0 { - m.dst.Flush() - return - } - if m.flushPending { - return - } - if m.t == nil { - m.t = time.AfterFunc(m.latency, m.delayedFlush) - } else { - m.t.Reset(m.latency) - } - m.flushPending = true - return -} - -func (m *maxLatencyWriter) delayedFlush() { - m.mu.Lock() - defer m.mu.Unlock() - if !m.flushPending { // if stop was called but AfterFunc already started this goroutine - return - } - m.dst.Flush() - m.flushPending = false -} - -func (m *maxLatencyWriter) stop() { - m.mu.Lock() - defer m.mu.Unlock() - m.flushPending = false - if m.t != nil { - m.t.Stop() +func cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 } + return h2 } func upgradeType(h http.Header) string { @@ -515,62 +524,71 @@ func upgradeType(h http.Header) string { return strings.ToLower(h.Get("Upgrade")) } -func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { - reqUpType := upgradeType(req.Header) - resUpType := upgradeType(res.Header) - if reqUpType != resUpType { - p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) - return +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b } + return a + b +} - copyHeader(res.Header, rw.Header()) - - hj, ok := rw.(http.Hijacker) - if !ok { - p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) - return - } - backConn, ok := res.Body.(io.ReadWriteCloser) - if !ok { - p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) - return +// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. +// See RFC 7230, section 6.1 +func removeConnectionHeaders(h http.Header) { + if c := h.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + h.Del(f) + } + } } - defer backConn.Close() - conn, brw, err := hj.Hijack() - if err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) - return - } - defer conn.Close() - res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above - if err := res.Write(brw); err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) - return - } - if err := brw.Flush(); err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) - return - } - errc := make(chan error, 1) - spc := switchProtocolCopier{user: conn, backend: backConn} - go spc.copyToBackend(errc) - go spc.copyFromBackend(errc) - <-errc - return } -// switchProtocolCopier exists so goroutines proxying data back and -// forth have nice names in stacks. -type switchProtocolCopier struct { - user, backend io.ReadWriter +// LoadBalancing has parameters related to load balancing. +type LoadBalancing struct { + SelectionPolicyRaw json.RawMessage `json:"selection_policy,omitempty"` + TryDuration caddy.Duration `json:"try_duration,omitempty"` + TryInterval caddy.Duration `json:"try_interval,omitempty"` + + SelectionPolicy Selector `json:"-"` } -func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { - _, err := io.Copy(c.user, c.backend) - errc <- err +// Selector selects an available upstream from the pool. +type Selector interface { + Select(UpstreamPool, *http.Request) *Upstream } -func (c switchProtocolCopier) copyToBackend(errc chan<- error) { - _, err := io.Copy(c.backend, c.user) - errc <- err +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", } + +// TODO: see if we can use this +// var bufPool = sync.Pool{ +// New: func() interface{} { +// return new(bytes.Buffer) +// }, +// } + +// Interface guards +var ( + _ caddy.Provisioner = (*Handler)(nil) + _ caddy.CleanerUpper = (*Handler)(nil) + _ caddyhttp.MiddlewareHandler = (*Handler)(nil) +) |