summaryrefslogtreecommitdiff
path: root/modules/caddyhttp/reverseproxy
diff options
context:
space:
mode:
Diffstat (limited to 'modules/caddyhttp/reverseproxy')
-rw-r--r--modules/caddyhttp/reverseproxy/caddyfile.go486
-rw-r--r--modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go54
-rw-r--r--modules/caddyhttp/reverseproxy/fastcgi/client.go578
-rw-r--r--modules/caddyhttp/reverseproxy/fastcgi/client_test.go301
-rw-r--r--modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go301
-rwxr-xr-xmodules/caddyhttp/reverseproxy/healthchecker.go86
-rw-r--r--modules/caddyhttp/reverseproxy/healthchecks.go270
-rw-r--r--modules/caddyhttp/reverseproxy/hosts.go193
-rw-r--r--modules/caddyhttp/reverseproxy/httptransport.go208
-rwxr-xr-xmodules/caddyhttp/reverseproxy/module.go53
-rw-r--r--[-rwxr-xr-x]modules/caddyhttp/reverseproxy/reverseproxy.go784
-rw-r--r--modules/caddyhttp/reverseproxy/selectionpolicies.go353
-rw-r--r--modules/caddyhttp/reverseproxy/selectionpolicies_test.go273
-rw-r--r--modules/caddyhttp/reverseproxy/streaming.go223
-rwxr-xr-xmodules/caddyhttp/reverseproxy/upstream.go450
15 files changed, 3641 insertions, 972 deletions
diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go
new file mode 100644
index 0000000..ffa3ca0
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/caddyfile.go
@@ -0,0 +1,486 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reverseproxy
+
+import (
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/caddyserver/caddy/v2"
+ "github.com/caddyserver/caddy/v2/caddyconfig"
+ "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
+ "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp"
+ "github.com/dustin/go-humanize"
+)
+
+func init() {
+ httpcaddyfile.RegisterHandlerDirective("reverse_proxy", parseCaddyfile)
+}
+
+func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
+ rp := new(Handler)
+ err := rp.UnmarshalCaddyfile(h.Dispenser)
+ if err != nil {
+ return nil, err
+ }
+ return rp, nil
+}
+
+// UnmarshalCaddyfile sets up the handler from Caddyfile tokens. Syntax:
+//
+// reverse_proxy [<matcher>] [<upstreams...>] {
+// # upstreams
+// to <upstreams...>
+//
+// # load balancing
+// lb_policy <name> [<options...>]
+// lb_try_duration <duration>
+// lb_try_interval <interval>
+//
+// # active health checking
+// health_path <path>
+// health_port <port>
+// health_interval <interval>
+// health_timeout <duration>
+// health_status <status>
+// health_body <regexp>
+//
+// # passive health checking
+// max_fails <num>
+// fail_duration <duration>
+// max_conns <num>
+// unhealthy_status <status>
+// unhealthy_latency <duration>
+//
+// # round trip
+// transport <name> {
+// ...
+// }
+// }
+//
+func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
+ for d.Next() {
+ for _, up := range d.RemainingArgs() {
+ h.Upstreams = append(h.Upstreams, &Upstream{
+ Dial: up,
+ })
+ }
+
+ for d.NextBlock() {
+ switch d.Val() {
+ case "to":
+ args := d.RemainingArgs()
+ if len(args) == 0 {
+ return d.ArgErr()
+ }
+ for _, up := range args {
+ h.Upstreams = append(h.Upstreams, &Upstream{
+ Dial: up,
+ })
+ }
+
+ case "lb_policy":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil {
+ return d.Err("load balancing selection policy already specified")
+ }
+ name := d.Val()
+ mod, err := caddy.GetModule("http.handlers.reverse_proxy.selection_policies." + name)
+ if err != nil {
+ return d.Errf("getting load balancing policy module '%s': %v", mod.Name, err)
+ }
+ unm, ok := mod.New().(caddyfile.Unmarshaler)
+ if !ok {
+ return d.Errf("load balancing policy module '%s' is not a Caddyfile unmarshaler", mod.Name)
+ }
+ err = unm.UnmarshalCaddyfile(d.NewFromNextTokens())
+ if err != nil {
+ return err
+ }
+ sel, ok := unm.(Selector)
+ if !ok {
+ return d.Errf("module %s is not a Selector", mod.Name)
+ }
+ if h.LoadBalancing == nil {
+ h.LoadBalancing = new(LoadBalancing)
+ }
+ h.LoadBalancing.SelectionPolicyRaw = caddyconfig.JSONModuleObject(sel, "policy", name, nil)
+
+ case "lb_try_duration":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.LoadBalancing == nil {
+ h.LoadBalancing = new(LoadBalancing)
+ }
+ dur, err := time.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad duration value %s: %v", d.Val(), err)
+ }
+ h.LoadBalancing.TryDuration = caddy.Duration(dur)
+
+ case "lb_try_interval":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.LoadBalancing == nil {
+ h.LoadBalancing = new(LoadBalancing)
+ }
+ dur, err := time.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad interval value '%s': %v", d.Val(), err)
+ }
+ h.LoadBalancing.TryInterval = caddy.Duration(dur)
+
+ case "health_path":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ h.HealthChecks.Active.Path = d.Val()
+
+ case "health_port":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ portNum, err := strconv.Atoi(d.Val())
+ if err != nil {
+ return d.Errf("bad port number '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Active.Port = portNum
+
+ case "health_interval":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ dur, err := time.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad interval value %s: %v", d.Val(), err)
+ }
+ h.HealthChecks.Active.Interval = caddy.Duration(dur)
+
+ case "health_timeout":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ dur, err := time.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad timeout value %s: %v", d.Val(), err)
+ }
+ h.HealthChecks.Active.Timeout = caddy.Duration(dur)
+
+ case "health_status":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ val := d.Val()
+ if len(val) == 3 && strings.HasSuffix(val, "xx") {
+ val = val[:1]
+ }
+ statusNum, err := strconv.Atoi(val[:1])
+ if err != nil {
+ return d.Errf("bad status value '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Active.ExpectStatus = statusNum
+
+ case "health_body":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Active == nil {
+ h.HealthChecks.Active = new(ActiveHealthChecks)
+ }
+ h.HealthChecks.Active.ExpectBody = d.Val()
+
+ case "max_fails":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Passive == nil {
+ h.HealthChecks.Passive = new(PassiveHealthChecks)
+ }
+ maxFails, err := strconv.Atoi(d.Val())
+ if err != nil {
+ return d.Errf("invalid maximum fail count '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Passive.MaxFails = maxFails
+
+ case "fail_duration":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Passive == nil {
+ h.HealthChecks.Passive = new(PassiveHealthChecks)
+ }
+ dur, err := time.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad duration value '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Passive.FailDuration = caddy.Duration(dur)
+
+ case "unhealthy_request_count":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Passive == nil {
+ h.HealthChecks.Passive = new(PassiveHealthChecks)
+ }
+ maxConns, err := strconv.Atoi(d.Val())
+ if err != nil {
+ return d.Errf("invalid maximum connection count '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Passive.UnhealthyRequestCount = maxConns
+
+ case "unhealthy_status":
+ args := d.RemainingArgs()
+ if len(args) == 0 {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Passive == nil {
+ h.HealthChecks.Passive = new(PassiveHealthChecks)
+ }
+ for _, arg := range args {
+ if len(arg) == 3 && strings.HasSuffix(arg, "xx") {
+ arg = arg[:1]
+ }
+ statusNum, err := strconv.Atoi(arg[:1])
+ if err != nil {
+ return d.Errf("bad status value '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Passive.UnhealthyStatus = append(h.HealthChecks.Passive.UnhealthyStatus, statusNum)
+ }
+
+ case "unhealthy_latency":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.HealthChecks == nil {
+ h.HealthChecks = new(HealthChecks)
+ }
+ if h.HealthChecks.Passive == nil {
+ h.HealthChecks.Passive = new(PassiveHealthChecks)
+ }
+ dur, err := time.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad duration value '%s': %v", d.Val(), err)
+ }
+ h.HealthChecks.Passive.UnhealthyLatency = caddy.Duration(dur)
+
+ case "transport":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.TransportRaw != nil {
+ return d.Err("transport already specified")
+ }
+ name := d.Val()
+ mod, err := caddy.GetModule("http.handlers.reverse_proxy.transport." + name)
+ if err != nil {
+ return d.Errf("getting transport module '%s': %v", mod.Name, err)
+ }
+ unm, ok := mod.New().(caddyfile.Unmarshaler)
+ if !ok {
+ return d.Errf("transport module '%s' is not a Caddyfile unmarshaler", mod.Name)
+ }
+ d.Next() // consume the module name token
+ err = unm.UnmarshalCaddyfile(d.NewFromNextTokens())
+ if err != nil {
+ return err
+ }
+ rt, ok := unm.(http.RoundTripper)
+ if !ok {
+ return d.Errf("module %s is not a RoundTripper", mod.Name)
+ }
+ h.TransportRaw = caddyconfig.JSONModuleObject(rt, "protocol", name, nil)
+
+ default:
+ return d.Errf("unrecognized subdirective %s", d.Val())
+ }
+ }
+ }
+
+ return nil
+}
+
+// UnmarshalCaddyfile deserializes Caddyfile tokens into h.
+//
+// transport http {
+// read_buffer <size>
+// write_buffer <size>
+// dial_timeout <duration>
+// tls_client_auth <cert_file> <key_file>
+// tls_insecure_skip_verify
+// tls_timeout <duration>
+// keepalive [off|<duration>]
+// keepalive_idle_conns <max_count>
+// }
+//
+func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
+ for d.NextBlock() {
+ switch d.Val() {
+ case "read_buffer":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ size, err := humanize.ParseBytes(d.Val())
+ if err != nil {
+ return d.Errf("invalid read buffer size '%s': %v", d.Val(), err)
+ }
+ h.ReadBufferSize = int(size)
+
+ case "write_buffer":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ size, err := humanize.ParseBytes(d.Val())
+ if err != nil {
+ return d.Errf("invalid write buffer size '%s': %v", d.Val(), err)
+ }
+ h.WriteBufferSize = int(size)
+
+ case "dial_timeout":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ dur, err := time.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad timeout value '%s': %v", d.Val(), err)
+ }
+ h.DialTimeout = caddy.Duration(dur)
+
+ case "tls_client_auth":
+ args := d.RemainingArgs()
+ if len(args) != 2 {
+ return d.ArgErr()
+ }
+ if h.TLS == nil {
+ h.TLS = new(TLSConfig)
+ }
+ h.TLS.ClientCertificateFile = args[0]
+ h.TLS.ClientCertificateKeyFile = args[1]
+
+ case "tls_insecure_skip_verify":
+ if d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.TLS == nil {
+ h.TLS = new(TLSConfig)
+ }
+ h.TLS.InsecureSkipVerify = true
+
+ case "tls_timeout":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ dur, err := time.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad timeout value '%s': %v", d.Val(), err)
+ }
+ if h.TLS == nil {
+ h.TLS = new(TLSConfig)
+ }
+ h.TLS.HandshakeTimeout = caddy.Duration(dur)
+
+ case "keepalive":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ if h.KeepAlive == nil {
+ h.KeepAlive = new(KeepAlive)
+ }
+ if d.Val() == "off" {
+ var disable bool
+ h.KeepAlive.Enabled = &disable
+ }
+ dur, err := time.ParseDuration(d.Val())
+ if err != nil {
+ return d.Errf("bad duration value '%s': %v", d.Val(), err)
+ }
+ h.KeepAlive.IdleConnTimeout = caddy.Duration(dur)
+
+ case "keepalive_idle_conns":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ num, err := strconv.Atoi(d.Val())
+ if err != nil {
+ return d.Errf("bad integer value '%s': %v", d.Val(), err)
+ }
+ if h.KeepAlive == nil {
+ h.KeepAlive = new(KeepAlive)
+ }
+ h.KeepAlive.MaxIdleConns = num
+
+ default:
+ return d.Errf("unrecognized subdirective %s", d.Val())
+ }
+ }
+ return nil
+}
+
+// Interface guards
+var (
+ _ caddyfile.Unmarshaler = (*Handler)(nil)
+ _ caddyfile.Unmarshaler = (*HTTPTransport)(nil)
+)
diff --git a/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go b/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go
new file mode 100644
index 0000000..c8b9f63
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go
@@ -0,0 +1,54 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fastcgi
+
+import "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
+
+// UnmarshalCaddyfile deserializes Caddyfile tokens into h.
+//
+// transport fastcgi {
+// root <path>
+// split <at>
+// env <key> <value>
+// }
+//
+func (t *Transport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
+ for d.NextBlock() {
+ switch d.Val() {
+ case "root":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ t.Root = d.Val()
+
+ case "split":
+ if !d.NextArg() {
+ return d.ArgErr()
+ }
+ t.SplitPath = d.Val()
+
+ case "env":
+ args := d.RemainingArgs()
+ if len(args) != 2 {
+ return d.ArgErr()
+ }
+ t.EnvVars = append(t.EnvVars, [2]string{args[0], args[1]})
+
+ default:
+ return d.Errf("unrecognized subdirective %s", d.Val())
+ }
+ }
+ return nil
+}
diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client.go b/modules/caddyhttp/reverseproxy/fastcgi/client.go
new file mode 100644
index 0000000..ae0de00
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/fastcgi/client.go
@@ -0,0 +1,578 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Forked Jan. 2015 from http://bitbucket.org/PinIdea/fcgi_client
+// (which is forked from https://code.google.com/p/go-fastcgi-client/).
+// This fork contains several fixes and improvements by Matt Holt and
+// other contributors to the Caddy project.
+
+// Copyright 2012 Junqing Tan <ivan@mysqlab.net> and The Go Authors
+// Use of this source code is governed by a BSD-style
+// Part of source code is from Go fcgi package
+
+package fastcgi
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/binary"
+ "errors"
+ "io"
+ "io/ioutil"
+ "mime/multipart"
+ "net"
+ "net/http"
+ "net/http/httputil"
+ "net/textproto"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+// FCGIListenSockFileno describes listen socket file number.
+const FCGIListenSockFileno uint8 = 0
+
+// FCGIHeaderLen describes header length.
+const FCGIHeaderLen uint8 = 8
+
+// Version1 describes the version.
+const Version1 uint8 = 1
+
+// FCGINullRequestID describes the null request ID.
+const FCGINullRequestID uint8 = 0
+
+// FCGIKeepConn describes keep connection mode.
+const FCGIKeepConn uint8 = 1
+
+const (
+ // BeginRequest is the begin request flag.
+ BeginRequest uint8 = iota + 1
+ // AbortRequest is the abort request flag.
+ AbortRequest
+ // EndRequest is the end request flag.
+ EndRequest
+ // Params is the parameters flag.
+ Params
+ // Stdin is the standard input flag.
+ Stdin
+ // Stdout is the standard output flag.
+ Stdout
+ // Stderr is the standard error flag.
+ Stderr
+ // Data is the data flag.
+ Data
+ // GetValues is the get values flag.
+ GetValues
+ // GetValuesResult is the get values result flag.
+ GetValuesResult
+ // UnknownType is the unknown type flag.
+ UnknownType
+ // MaxType is the maximum type flag.
+ MaxType = UnknownType
+)
+
+const (
+ // Responder is the responder flag.
+ Responder uint8 = iota + 1
+ // Authorizer is the authorizer flag.
+ Authorizer
+ // Filter is the filter flag.
+ Filter
+)
+
+const (
+ // RequestComplete is the completed request flag.
+ RequestComplete uint8 = iota
+ // CantMultiplexConns is the multiplexed connections flag.
+ CantMultiplexConns
+ // Overloaded is the overloaded flag.
+ Overloaded
+ // UnknownRole is the unknown role flag.
+ UnknownRole
+)
+
+const (
+ // MaxConns is the maximum connections flag.
+ MaxConns string = "MAX_CONNS"
+ // MaxRequests is the maximum requests flag.
+ MaxRequests string = "MAX_REQS"
+ // MultiplexConns is the multiplex connections flag.
+ MultiplexConns string = "MPXS_CONNS"
+)
+
+const (
+ maxWrite = 65500 // 65530 may work, but for compatibility
+ maxPad = 255
+)
+
+type header struct {
+ Version uint8
+ Type uint8
+ ID uint16
+ ContentLength uint16
+ PaddingLength uint8
+ Reserved uint8
+}
+
+// for padding so we don't have to allocate all the time
+// not synchronized because we don't care what the contents are
+var pad [maxPad]byte
+
+func (h *header) init(recType uint8, reqID uint16, contentLength int) {
+ h.Version = 1
+ h.Type = recType
+ h.ID = reqID
+ h.ContentLength = uint16(contentLength)
+ h.PaddingLength = uint8(-contentLength & 7)
+}
+
+type record struct {
+ h header
+ rbuf []byte
+}
+
+func (rec *record) read(r io.Reader) (buf []byte, err error) {
+ if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil {
+ return
+ }
+ if rec.h.Version != 1 {
+ err = errors.New("fcgi: invalid header version")
+ return
+ }
+ if rec.h.Type == EndRequest {
+ err = io.EOF
+ return
+ }
+ n := int(rec.h.ContentLength) + int(rec.h.PaddingLength)
+ if len(rec.rbuf) < n {
+ rec.rbuf = make([]byte, n)
+ }
+ if _, err = io.ReadFull(r, rec.rbuf[:n]); err != nil {
+ return
+ }
+ buf = rec.rbuf[:int(rec.h.ContentLength)]
+
+ return
+}
+
+// FCGIClient implements a FastCGI client, which is a standard for
+// interfacing external applications with Web servers.
+type FCGIClient struct {
+ mutex sync.Mutex
+ rwc io.ReadWriteCloser
+ h header
+ buf bytes.Buffer
+ stderr bytes.Buffer
+ keepAlive bool
+ reqID uint16
+}
+
+// DialWithDialerContext connects to the fcgi responder at the specified network address, using custom net.Dialer
+// and a context.
+// See func net.Dial for a description of the network and address parameters.
+func DialWithDialerContext(ctx context.Context, network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) {
+ var conn net.Conn
+ conn, err = dialer.DialContext(ctx, network, address)
+ if err != nil {
+ return
+ }
+
+ fcgi = &FCGIClient{
+ rwc: conn,
+ keepAlive: false,
+ reqID: 1,
+ }
+
+ return
+}
+
+// DialContext is like Dial but passes ctx to dialer.Dial.
+func DialContext(ctx context.Context, network, address string) (fcgi *FCGIClient, err error) {
+ // TODO: why not set timeout here?
+ return DialWithDialerContext(ctx, network, address, net.Dialer{})
+}
+
+// Dial connects to the fcgi responder at the specified network address, using default net.Dialer.
+// See func net.Dial for a description of the network and address parameters.
+func Dial(network, address string) (fcgi *FCGIClient, err error) {
+ return DialContext(context.Background(), network, address)
+}
+
+// Close closes fcgi connection
+func (c *FCGIClient) Close() {
+ c.rwc.Close()
+}
+
+func (c *FCGIClient) writeRecord(recType uint8, content []byte) (err error) {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ c.buf.Reset()
+ c.h.init(recType, c.reqID, len(content))
+ if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(content); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil {
+ return err
+ }
+ _, err = c.rwc.Write(c.buf.Bytes())
+ return err
+}
+
+func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error {
+ b := [8]byte{byte(role >> 8), byte(role), flags}
+ return c.writeRecord(BeginRequest, b[:])
+}
+
+func (c *FCGIClient) writeEndRequest(appStatus int, protocolStatus uint8) error {
+ b := make([]byte, 8)
+ binary.BigEndian.PutUint32(b, uint32(appStatus))
+ b[4] = protocolStatus
+ return c.writeRecord(EndRequest, b)
+}
+
+func (c *FCGIClient) writePairs(recType uint8, pairs map[string]string) error {
+ w := newWriter(c, recType)
+ b := make([]byte, 8)
+ nn := 0
+ for k, v := range pairs {
+ m := 8 + len(k) + len(v)
+ if m > maxWrite {
+ // param data size exceed 65535 bytes"
+ vl := maxWrite - 8 - len(k)
+ v = v[:vl]
+ }
+ n := encodeSize(b, uint32(len(k)))
+ n += encodeSize(b[n:], uint32(len(v)))
+ m = n + len(k) + len(v)
+ if (nn + m) > maxWrite {
+ w.Flush()
+ nn = 0
+ }
+ nn += m
+ if _, err := w.Write(b[:n]); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(k); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(v); err != nil {
+ return err
+ }
+ }
+ w.Close()
+ return nil
+}
+
+func encodeSize(b []byte, size uint32) int {
+ if size > 127 {
+ size |= 1 << 31
+ binary.BigEndian.PutUint32(b, size)
+ return 4
+ }
+ b[0] = byte(size)
+ return 1
+}
+
+// bufWriter encapsulates bufio.Writer but also closes the underlying stream when
+// Closed.
+type bufWriter struct {
+ closer io.Closer
+ *bufio.Writer
+}
+
+func (w *bufWriter) Close() error {
+ if err := w.Writer.Flush(); err != nil {
+ w.closer.Close()
+ return err
+ }
+ return w.closer.Close()
+}
+
+func newWriter(c *FCGIClient, recType uint8) *bufWriter {
+ s := &streamWriter{c: c, recType: recType}
+ w := bufio.NewWriterSize(s, maxWrite)
+ return &bufWriter{s, w}
+}
+
+// streamWriter abstracts out the separation of a stream into discrete records.
+// It only writes maxWrite bytes at a time.
+type streamWriter struct {
+ c *FCGIClient
+ recType uint8
+}
+
+func (w *streamWriter) Write(p []byte) (int, error) {
+ nn := 0
+ for len(p) > 0 {
+ n := len(p)
+ if n > maxWrite {
+ n = maxWrite
+ }
+ if err := w.c.writeRecord(w.recType, p[:n]); err != nil {
+ return nn, err
+ }
+ nn += n
+ p = p[n:]
+ }
+ return nn, nil
+}
+
+func (w *streamWriter) Close() error {
+ // send empty record to close the stream
+ return w.c.writeRecord(w.recType, nil)
+}
+
+type streamReader struct {
+ c *FCGIClient
+ buf []byte
+}
+
+func (w *streamReader) Read(p []byte) (n int, err error) {
+
+ if len(p) > 0 {
+ if len(w.buf) == 0 {
+
+ // filter outputs for error log
+ for {
+ rec := &record{}
+ var buf []byte
+ buf, err = rec.read(w.c.rwc)
+ if err != nil {
+ return
+ }
+ // standard error output
+ if rec.h.Type == Stderr {
+ w.c.stderr.Write(buf)
+ continue
+ }
+ w.buf = buf
+ break
+ }
+ }
+
+ n = len(p)
+ if n > len(w.buf) {
+ n = len(w.buf)
+ }
+ copy(p, w.buf[:n])
+ w.buf = w.buf[n:]
+ }
+
+ return
+}
+
+// Do made the request and returns a io.Reader that translates the data read
+// from fcgi responder out of fcgi packet before returning it.
+func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) {
+ err = c.writeBeginRequest(uint16(Responder), 0)
+ if err != nil {
+ return
+ }
+
+ err = c.writePairs(Params, p)
+ if err != nil {
+ return
+ }
+
+ body := newWriter(c, Stdin)
+ if req != nil {
+ _, _ = io.Copy(body, req)
+ }
+ body.Close()
+
+ r = &streamReader{c: c}
+ return
+}
+
+// clientCloser is a io.ReadCloser. It wraps a io.Reader with a Closer
+// that closes FCGIClient connection.
+type clientCloser struct {
+ *FCGIClient
+ io.Reader
+}
+
+func (f clientCloser) Close() error { return f.rwc.Close() }
+
+// Request returns a HTTP Response with Header and Body
+// from fcgi responder
+func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Response, err error) {
+ r, err := c.Do(p, req)
+ if err != nil {
+ return
+ }
+
+ rb := bufio.NewReader(r)
+ tp := textproto.NewReader(rb)
+ resp = new(http.Response)
+
+ // Parse the response headers.
+ mimeHeader, err := tp.ReadMIMEHeader()
+ if err != nil && err != io.EOF {
+ return
+ }
+ resp.Header = http.Header(mimeHeader)
+
+ if resp.Header.Get("Status") != "" {
+ statusParts := strings.SplitN(resp.Header.Get("Status"), " ", 2)
+ resp.StatusCode, err = strconv.Atoi(statusParts[0])
+ if err != nil {
+ return
+ }
+ if len(statusParts) > 1 {
+ resp.Status = statusParts[1]
+ }
+
+ } else {
+ resp.StatusCode = http.StatusOK
+ }
+
+ // TODO: fixTransferEncoding ?
+ resp.TransferEncoding = resp.Header["Transfer-Encoding"]
+ resp.ContentLength, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
+
+ if chunked(resp.TransferEncoding) {
+ resp.Body = clientCloser{c, httputil.NewChunkedReader(rb)}
+ } else {
+ resp.Body = clientCloser{c, ioutil.NopCloser(rb)}
+ }
+ return
+}
+
+// Get issues a GET request to the fcgi responder.
+func (c *FCGIClient) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) {
+
+ p["REQUEST_METHOD"] = "GET"
+ p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10)
+
+ return c.Request(p, body)
+}
+
+// Head issues a HEAD request to the fcgi responder.
+func (c *FCGIClient) Head(p map[string]string) (resp *http.Response, err error) {
+
+ p["REQUEST_METHOD"] = "HEAD"
+ p["CONTENT_LENGTH"] = "0"
+
+ return c.Request(p, nil)
+}
+
+// Options issues an OPTIONS request to the fcgi responder.
+func (c *FCGIClient) Options(p map[string]string) (resp *http.Response, err error) {
+
+ p["REQUEST_METHOD"] = "OPTIONS"
+ p["CONTENT_LENGTH"] = "0"
+
+ return c.Request(p, nil)
+}
+
+// Post issues a POST request to the fcgi responder. with request body
+// in the format that bodyType specified
+func (c *FCGIClient) Post(p map[string]string, method string, bodyType string, body io.Reader, l int64) (resp *http.Response, err error) {
+ if p == nil {
+ p = make(map[string]string)
+ }
+
+ p["REQUEST_METHOD"] = strings.ToUpper(method)
+
+ if len(p["REQUEST_METHOD"]) == 0 || p["REQUEST_METHOD"] == "GET" {
+ p["REQUEST_METHOD"] = "POST"
+ }
+
+ p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10)
+ if len(bodyType) > 0 {
+ p["CONTENT_TYPE"] = bodyType
+ } else {
+ p["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
+ }
+
+ return c.Request(p, body)
+}
+
+// PostForm issues a POST to the fcgi responder, with form
+// as a string key to a list values (url.Values)
+func (c *FCGIClient) PostForm(p map[string]string, data url.Values) (resp *http.Response, err error) {
+ body := bytes.NewReader([]byte(data.Encode()))
+ return c.Post(p, "POST", "application/x-www-form-urlencoded", body, int64(body.Len()))
+}
+
+// PostFile issues a POST to the fcgi responder in multipart(RFC 2046) standard,
+// with form as a string key to a list values (url.Values),
+// and/or with file as a string key to a list file path.
+func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[string]string) (resp *http.Response, err error) {
+ buf := &bytes.Buffer{}
+ writer := multipart.NewWriter(buf)
+ bodyType := writer.FormDataContentType()
+
+ for key, val := range data {
+ for _, v0 := range val {
+ err = writer.WriteField(key, v0)
+ if err != nil {
+ return
+ }
+ }
+ }
+
+ for key, val := range file {
+ fd, e := os.Open(val)
+ if e != nil {
+ return nil, e
+ }
+ defer fd.Close()
+
+ part, e := writer.CreateFormFile(key, filepath.Base(val))
+ if e != nil {
+ return nil, e
+ }
+ _, err = io.Copy(part, fd)
+ if err != nil {
+ return
+ }
+ }
+
+ err = writer.Close()
+ if err != nil {
+ return
+ }
+
+ return c.Post(p, "POST", bodyType, buf, int64(buf.Len()))
+}
+
+// SetReadTimeout sets the read timeout for future calls that read from the
+// fcgi responder. A zero value for t means no timeout will be set.
+func (c *FCGIClient) SetReadTimeout(t time.Duration) error {
+ if conn, ok := c.rwc.(net.Conn); ok && t != 0 {
+ return conn.SetReadDeadline(time.Now().Add(t))
+ }
+ return nil
+}
+
+// SetWriteTimeout sets the write timeout for future calls that send data to
+// the fcgi responder. A zero value for t means no timeout will be set.
+func (c *FCGIClient) SetWriteTimeout(t time.Duration) error {
+ if conn, ok := c.rwc.(net.Conn); ok && t != 0 {
+ return conn.SetWriteDeadline(time.Now().Add(t))
+ }
+ return nil
+}
+
+// Checks whether chunked is part of the encodings stack
+func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client_test.go b/modules/caddyhttp/reverseproxy/fastcgi/client_test.go
new file mode 100644
index 0000000..c090f3c
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/fastcgi/client_test.go
@@ -0,0 +1,301 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE: These tests were adapted from the original
+// repository from which this package was forked.
+// The tests are slow (~10s) and in dire need of rewriting.
+// As such, the tests have been disabled to speed up
+// automated builds until they can be properly written.
+
+package fastcgi
+
+import (
+ "bytes"
+ "crypto/md5"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "math/rand"
+ "net"
+ "net/http"
+ "net/http/fcgi"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+// test fcgi protocol includes:
+// Get, Post, Post in multipart/form-data, and Post with files
+// each key should be the md5 of the value or the file uploaded
+// specify remote fcgi responder ip:port to test with php
+// test failed if the remote fcgi(script) failed md5 verification
+// and output "FAILED" in response
+const (
+ scriptFile = "/tank/www/fcgic_test.php"
+ //ipPort = "remote-php-serv:59000"
+ ipPort = "127.0.0.1:59000"
+)
+
+var globalt *testing.T
+
+type FastCGIServer struct{}
+
+func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+
+ if err := req.ParseMultipartForm(100000000); err != nil {
+ log.Printf("[ERROR] failed to parse: %v", err)
+ }
+
+ stat := "PASSED"
+ fmt.Fprintln(resp, "-")
+ fileNum := 0
+ {
+ length := 0
+ for k0, v0 := range req.Form {
+ h := md5.New()
+ _, _ = io.WriteString(h, v0[0])
+ _md5 := fmt.Sprintf("%x", h.Sum(nil))
+
+ length += len(k0)
+ length += len(v0[0])
+
+ // echo error when key != _md5(val)
+ if _md5 != k0 {
+ fmt.Fprintln(resp, "server:err ", _md5, k0)
+ stat = "FAILED"
+ }
+ }
+ if req.MultipartForm != nil {
+ fileNum = len(req.MultipartForm.File)
+ for kn, fns := range req.MultipartForm.File {
+ //fmt.Fprintln(resp, "server:filekey ", kn )
+ length += len(kn)
+ for _, f := range fns {
+ fd, err := f.Open()
+ if err != nil {
+ log.Println("server:", err)
+ return
+ }
+ h := md5.New()
+ l0, err := io.Copy(h, fd)
+ if err != nil {
+ log.Println(err)
+ return
+ }
+ length += int(l0)
+ defer fd.Close()
+ md5 := fmt.Sprintf("%x", h.Sum(nil))
+ //fmt.Fprintln(resp, "server:filemd5 ", md5 )
+
+ if kn != md5 {
+ fmt.Fprintln(resp, "server:err ", md5, kn)
+ stat = "FAILED"
+ }
+ //fmt.Fprintln(resp, "server:filename ", f.Filename )
+ }
+ }
+ }
+
+ fmt.Fprintln(resp, "server:got data length", length)
+ }
+ fmt.Fprintln(resp, "-"+stat+"-POST(", len(req.Form), ")-FILE(", fileNum, ")--")
+}
+
+func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) {
+ fcgi, err := Dial("tcp", ipPort)
+ if err != nil {
+ log.Println("err:", err)
+ return
+ }
+
+ length := 0
+
+ var resp *http.Response
+ switch reqType {
+ case 0:
+ if len(data) > 0 {
+ length = len(data)
+ rd := bytes.NewReader(data)
+ resp, err = fcgi.Post(fcgiParams, "", "", rd, int64(rd.Len()))
+ } else if len(posts) > 0 {
+ values := url.Values{}
+ for k, v := range posts {
+ values.Set(k, v)
+ length += len(k) + 2 + len(v)
+ }
+ resp, err = fcgi.PostForm(fcgiParams, values)
+ } else {
+ rd := bytes.NewReader(data)
+ resp, err = fcgi.Get(fcgiParams, rd, int64(rd.Len()))
+ }
+
+ default:
+ values := url.Values{}
+ for k, v := range posts {
+ values.Set(k, v)
+ length += len(k) + 2 + len(v)
+ }
+
+ for k, v := range files {
+ fi, _ := os.Lstat(v)
+ length += len(k) + int(fi.Size())
+ }
+ resp, err = fcgi.PostFile(fcgiParams, values, files)
+ }
+
+ if err != nil {
+ log.Println("err:", err)
+ return
+ }
+
+ defer resp.Body.Close()
+ content, _ = ioutil.ReadAll(resp.Body)
+
+ log.Println("c: send data length ≈", length, string(content))
+ fcgi.Close()
+ time.Sleep(1 * time.Second)
+
+ if bytes.Contains(content, []byte("FAILED")) {
+ globalt.Error("Server return failed message")
+ }
+
+ return
+}
+
+func generateRandFile(size int) (p string, m string) {
+
+ p = filepath.Join(os.TempDir(), "fcgict"+strconv.Itoa(rand.Int()))
+
+ // open output file
+ fo, err := os.Create(p)
+ if err != nil {
+ panic(err)
+ }
+ // close fo on exit and check for its returned error
+ defer func() {
+ if err := fo.Close(); err != nil {
+ panic(err)
+ }
+ }()
+
+ h := md5.New()
+ for i := 0; i < size/16; i++ {
+ buf := make([]byte, 16)
+ binary.PutVarint(buf, rand.Int63())
+ if _, err := fo.Write(buf); err != nil {
+ log.Printf("[ERROR] failed to write buffer: %v\n", err)
+ }
+ if _, err := h.Write(buf); err != nil {
+ log.Printf("[ERROR] failed to write buffer: %v\n", err)
+ }
+ }
+ m = fmt.Sprintf("%x", h.Sum(nil))
+ return
+}
+
+func DisabledTest(t *testing.T) {
+ // TODO: test chunked reader
+ globalt = t
+
+ rand.Seed(time.Now().UTC().UnixNano())
+
+ // server
+ go func() {
+ listener, err := net.Listen("tcp", ipPort)
+ if err != nil {
+ log.Println("listener creation failed: ", err)
+ }
+
+ srv := new(FastCGIServer)
+ if err := fcgi.Serve(listener, srv); err != nil {
+ log.Print("[ERROR] failed to start server: ", err)
+ }
+ }()
+
+ time.Sleep(1 * time.Second)
+
+ // init
+ fcgiParams := make(map[string]string)
+ fcgiParams["REQUEST_METHOD"] = "GET"
+ fcgiParams["SERVER_PROTOCOL"] = "HTTP/1.1"
+ //fcgi_params["GATEWAY_INTERFACE"] = "CGI/1.1"
+ fcgiParams["SCRIPT_FILENAME"] = scriptFile
+
+ // simple GET
+ log.Println("test:", "get")
+ sendFcgi(0, fcgiParams, nil, nil, nil)
+
+ // simple post data
+ log.Println("test:", "post")
+ sendFcgi(0, fcgiParams, []byte("c4ca4238a0b923820dcc509a6f75849b=1&7b8b965ad4bca0e41ab51de7b31363a1=n"), nil, nil)
+
+ log.Println("test:", "post data (more than 60KB)")
+ data := ""
+ for i := 0x00; i < 0xff; i++ {
+ v0 := strings.Repeat(string(i), 256)
+ h := md5.New()
+ _, _ = io.WriteString(h, v0)
+ k0 := fmt.Sprintf("%x", h.Sum(nil))
+ data += k0 + "=" + url.QueryEscape(v0) + "&"
+ }
+ sendFcgi(0, fcgiParams, []byte(data), nil, nil)
+
+ log.Println("test:", "post form (use url.Values)")
+ p0 := make(map[string]string, 1)
+ p0["c4ca4238a0b923820dcc509a6f75849b"] = "1"
+ p0["7b8b965ad4bca0e41ab51de7b31363a1"] = "n"
+ sendFcgi(1, fcgiParams, nil, p0, nil)
+
+ log.Println("test:", "post forms (256 keys, more than 1MB)")
+ p1 := make(map[string]string, 1)
+ for i := 0x00; i < 0xff; i++ {
+ v0 := strings.Repeat(string(i), 4096)
+ h := md5.New()
+ _, _ = io.WriteString(h, v0)
+ k0 := fmt.Sprintf("%x", h.Sum(nil))
+ p1[k0] = v0
+ }
+ sendFcgi(1, fcgiParams, nil, p1, nil)
+
+ log.Println("test:", "post file (1 file, 500KB)) ")
+ f0 := make(map[string]string, 1)
+ path0, m0 := generateRandFile(500000)
+ f0[m0] = path0
+ sendFcgi(1, fcgiParams, nil, p1, f0)
+
+ log.Println("test:", "post multiple files (2 files, 5M each) and forms (256 keys, more than 1MB data")
+ path1, m1 := generateRandFile(5000000)
+ f0[m1] = path1
+ sendFcgi(1, fcgiParams, nil, p1, f0)
+
+ log.Println("test:", "post only files (2 files, 5M each)")
+ sendFcgi(1, fcgiParams, nil, nil, f0)
+
+ log.Println("test:", "post only 1 file")
+ delete(f0, "m0")
+ sendFcgi(1, fcgiParams, nil, nil, f0)
+
+ if err := os.Remove(path0); err != nil {
+ log.Println("[ERROR] failed to remove path: ", err)
+ }
+ if err := os.Remove(path1); err != nil {
+ log.Println("[ERROR] failed to remove path: ", err)
+ }
+}
diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go
new file mode 100644
index 0000000..91039c9
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go
@@ -0,0 +1,301 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fastcgi
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "net/http"
+ "net/url"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp"
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy"
+ "github.com/caddyserver/caddy/v2/modules/caddytls"
+
+ "github.com/caddyserver/caddy/v2"
+)
+
+func init() {
+ caddy.RegisterModule(Transport{})
+}
+
+// Transport facilitates FastCGI communication.
+type Transport struct {
+ // TODO: Populate these
+ softwareName string
+ softwareVersion string
+ serverName string
+ serverPort string
+
+ // Use this directory as the fastcgi root directory. Defaults to the root
+ // directory of the parent virtual host.
+ Root string `json:"root,omitempty"`
+
+ // The path in the URL will be split into two, with the first piece ending
+ // with the value of SplitPath. The first piece will be assumed as the
+ // actual resource (CGI script) name, and the second piece will be set to
+ // PATH_INFO for the CGI script to use.
+ SplitPath string `json:"split_path,omitempty"`
+
+ // Environment variables (TODO: make a map of string to value...?)
+ EnvVars [][2]string `json:"env,omitempty"`
+
+ // The duration used to set a deadline when connecting to an upstream.
+ DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
+
+ // The duration used to set a deadline when reading from the FastCGI server.
+ ReadTimeout caddy.Duration `json:"read_timeout,omitempty"`
+
+ // The duration used to set a deadline when sending to the FastCGI server.
+ WriteTimeout caddy.Duration `json:"write_timeout,omitempty"`
+}
+
+// CaddyModule returns the Caddy module information.
+func (Transport) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.transport.fastcgi",
+ New: func() caddy.Module { return new(Transport) },
+ }
+}
+
+// Provision sets up t.
+func (t *Transport) Provision(_ caddy.Context) error {
+ if t.Root == "" {
+ t.Root = "{http.vars.root}"
+ }
+ return nil
+}
+
+// RoundTrip implements http.RoundTripper.
+func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
+ env, err := t.buildEnv(r)
+ if err != nil {
+ return nil, fmt.Errorf("building environment: %v", err)
+ }
+
+ // TODO: doesn't dialer have a Timeout field?
+ ctx := r.Context()
+ if t.DialTimeout > 0 {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithTimeout(ctx, time.Duration(t.DialTimeout))
+ defer cancel()
+ }
+
+ // extract dial information from request (this
+ // should embedded by the reverse proxy)
+ network, address := "tcp", r.URL.Host
+ if dialInfoVal := ctx.Value(reverseproxy.DialInfoCtxKey); dialInfoVal != nil {
+ dialInfo := dialInfoVal.(reverseproxy.DialInfo)
+ network = dialInfo.Network
+ address = dialInfo.Address
+ }
+
+ fcgiBackend, err := DialContext(ctx, network, address)
+ if err != nil {
+ return nil, fmt.Errorf("dialing backend: %v", err)
+ }
+ // fcgiBackend gets closed when response body is closed (see clientCloser)
+
+ // read/write timeouts
+ if err := fcgiBackend.SetReadTimeout(time.Duration(t.ReadTimeout)); err != nil {
+ return nil, fmt.Errorf("setting read timeout: %v", err)
+ }
+ if err := fcgiBackend.SetWriteTimeout(time.Duration(t.WriteTimeout)); err != nil {
+ return nil, fmt.Errorf("setting write timeout: %v", err)
+ }
+
+ contentLength := r.ContentLength
+ if contentLength == 0 {
+ contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
+ }
+
+ var resp *http.Response
+ switch r.Method {
+ case http.MethodHead:
+ resp, err = fcgiBackend.Head(env)
+ case http.MethodGet:
+ resp, err = fcgiBackend.Get(env, r.Body, contentLength)
+ case http.MethodOptions:
+ resp, err = fcgiBackend.Options(env)
+ default:
+ resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
+ }
+
+ return resp, err
+}
+
+// buildEnv returns a set of CGI environment variables for the request.
+func (t Transport) buildEnv(r *http.Request) (map[string]string, error) {
+ repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer)
+
+ var env map[string]string
+
+ // Separate remote IP and port; more lenient than net.SplitHostPort
+ var ip, port string
+ if idx := strings.LastIndex(r.RemoteAddr, ":"); idx > -1 {
+ ip = r.RemoteAddr[:idx]
+ port = r.RemoteAddr[idx+1:]
+ } else {
+ ip = r.RemoteAddr
+ }
+
+ // Remove [] from IPv6 addresses
+ ip = strings.Replace(ip, "[", "", 1)
+ ip = strings.Replace(ip, "]", "", 1)
+
+ root := repl.ReplaceAll(t.Root, ".")
+ fpath := r.URL.Path
+
+ // Split path in preparation for env variables.
+ // Previous canSplit checks ensure this can never be -1.
+ // TODO: I haven't brought over canSplit; make sure this doesn't break
+ splitPos := t.splitPos(fpath)
+
+ // Request has the extension; path was split successfully
+ docURI := fpath[:splitPos+len(t.SplitPath)]
+ pathInfo := fpath[splitPos+len(t.SplitPath):]
+ scriptName := fpath
+
+ // Strip PATH_INFO from SCRIPT_NAME
+ scriptName = strings.TrimSuffix(scriptName, pathInfo)
+
+ // SCRIPT_FILENAME is the absolute path of SCRIPT_NAME
+ scriptFilename := filepath.Join(root, scriptName)
+
+ // Add vhost path prefix to scriptName. Otherwise, some PHP software will
+ // have difficulty discovering its URL.
+ pathPrefix, _ := r.Context().Value(caddy.CtxKey("path_prefix")).(string)
+ scriptName = path.Join(pathPrefix, scriptName)
+
+ // Get the request URL from context. The context stores the original URL in case
+ // it was changed by a middleware such as rewrite. By default, we pass the
+ // original URI in as the value of REQUEST_URI (the user can overwrite this
+ // if desired). Most PHP apps seem to want the original URI. Besides, this is
+ // how nginx defaults: http://stackoverflow.com/a/12485156/1048862
+ reqURL, ok := r.Context().Value(caddyhttp.OriginalURLCtxKey).(url.URL)
+ if !ok {
+ // some requests, like active health checks, don't add this to
+ // the request context, so we can just use the current URL
+ reqURL = *r.URL
+ }
+
+ requestScheme := "http"
+ if r.TLS != nil {
+ requestScheme = "https"
+ }
+
+ // Some variables are unused but cleared explicitly to prevent
+ // the parent environment from interfering.
+ env = map[string]string{
+ // Variables defined in CGI 1.1 spec
+ "AUTH_TYPE": "", // Not used
+ "CONTENT_LENGTH": r.Header.Get("Content-Length"),
+ "CONTENT_TYPE": r.Header.Get("Content-Type"),
+ "GATEWAY_INTERFACE": "CGI/1.1",
+ "PATH_INFO": pathInfo,
+ "QUERY_STRING": r.URL.RawQuery,
+ "REMOTE_ADDR": ip,
+ "REMOTE_HOST": ip, // For speed, remote host lookups disabled
+ "REMOTE_PORT": port,
+ "REMOTE_IDENT": "", // Not used
+ "REMOTE_USER": "", // TODO: once there are authentication handlers, populate this
+ "REQUEST_METHOD": r.Method,
+ "REQUEST_SCHEME": requestScheme,
+ "SERVER_NAME": t.serverName,
+ "SERVER_PORT": t.serverPort,
+ "SERVER_PROTOCOL": r.Proto,
+ "SERVER_SOFTWARE": t.softwareName + "/" + t.softwareVersion,
+
+ // Other variables
+ "DOCUMENT_ROOT": root,
+ "DOCUMENT_URI": docURI,
+ "HTTP_HOST": r.Host, // added here, since not always part of headers
+ "REQUEST_URI": reqURL.RequestURI(),
+ "SCRIPT_FILENAME": scriptFilename,
+ "SCRIPT_NAME": scriptName,
+ }
+
+ // compliance with the CGI specification requires that
+ // PATH_TRANSLATED should only exist if PATH_INFO is defined.
+ // Info: https://www.ietf.org/rfc/rfc3875 Page 14
+ if env["PATH_INFO"] != "" {
+ env["PATH_TRANSLATED"] = filepath.Join(root, pathInfo) // Info: http://www.oreilly.com/openbook/cgi/ch02_04.html
+ }
+
+ // Some web apps rely on knowing HTTPS or not
+ if r.TLS != nil {
+ env["HTTPS"] = "on"
+ // and pass the protocol details in a manner compatible with apache's mod_ssl
+ // (which is why these have a SSL_ prefix and not TLS_).
+ v, ok := tlsProtocolStrings[r.TLS.Version]
+ if ok {
+ env["SSL_PROTOCOL"] = v
+ }
+ // and pass the cipher suite in a manner compatible with apache's mod_ssl
+ for k, v := range caddytls.SupportedCipherSuites {
+ if v == r.TLS.CipherSuite {
+ env["SSL_CIPHER"] = k
+ break
+ }
+ }
+ }
+
+ // Add env variables from config (with support for placeholders in values)
+ for _, envVar := range t.EnvVars {
+ env[envVar[0]] = repl.ReplaceAll(envVar[1], "")
+ }
+
+ // Add all HTTP headers to env variables
+ for field, val := range r.Header {
+ header := strings.ToUpper(field)
+ header = headerNameReplacer.Replace(header)
+ env["HTTP_"+header] = strings.Join(val, ", ")
+ }
+ return env, nil
+}
+
+// splitPos returns the index where path should
+// be split based on t.SplitPath.
+func (t Transport) splitPos(path string) int {
+ // TODO:
+ // if httpserver.CaseSensitivePath {
+ // return strings.Index(path, r.SplitPath)
+ // }
+ return strings.Index(strings.ToLower(path), strings.ToLower(t.SplitPath))
+}
+
+// TODO:
+// Map of supported protocols to Apache ssl_mod format
+// Note that these are slightly different from SupportedProtocols in caddytls/config.go
+var tlsProtocolStrings = map[uint16]string{
+ tls.VersionTLS10: "TLSv1",
+ tls.VersionTLS11: "TLSv1.1",
+ tls.VersionTLS12: "TLSv1.2",
+ tls.VersionTLS13: "TLSv1.3",
+}
+
+var headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_")
+
+// Interface guards
+var (
+ _ caddy.Provisioner = (*Transport)(nil)
+ _ http.RoundTripper = (*Transport)(nil)
+)
diff --git a/modules/caddyhttp/reverseproxy/healthchecker.go b/modules/caddyhttp/reverseproxy/healthchecker.go
deleted file mode 100755
index c557d3f..0000000
--- a/modules/caddyhttp/reverseproxy/healthchecker.go
+++ /dev/null
@@ -1,86 +0,0 @@
-// Copyright 2015 Matthew Holt and The Caddy Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package reverseproxy
-
-import (
- "net/http"
- "time"
-)
-
-// Upstream represents the interface that must be satisfied to use the healthchecker.
-type Upstream interface {
- SetHealthiness(bool)
-}
-
-// HealthChecker represents a worker that periodically evaluates if proxy upstream host is healthy.
-type HealthChecker struct {
- upstream Upstream
- Ticker *time.Ticker
- HTTPClient *http.Client
- StopChan chan bool
-}
-
-// ScheduleChecks periodically runs health checks against an upstream host.
-func (h *HealthChecker) ScheduleChecks(url string) {
- // check if a host is healthy on start vs waiting for timer
- h.upstream.SetHealthiness(h.IsHealthy(url))
- stop := make(chan bool)
- h.StopChan = stop
-
- go func() {
- for {
- select {
- case <-h.Ticker.C:
- h.upstream.SetHealthiness(h.IsHealthy(url))
- case <-stop:
- return
- }
- }
- }()
-}
-
-// Stop stops the healthchecker from makeing further requests.
-func (h *HealthChecker) Stop() {
- h.Ticker.Stop()
- close(h.StopChan)
-}
-
-// IsHealthy attempts to check if a upstream host is healthy.
-func (h *HealthChecker) IsHealthy(url string) bool {
- req, err := http.NewRequest(http.MethodGet, url, nil)
- if err != nil {
- return false
- }
-
- resp, err := h.HTTPClient.Do(req)
- if err != nil {
- return false
- }
-
- if resp.StatusCode < 200 || resp.StatusCode >= 400 {
- return false
- }
-
- return true
-}
-
-// NewHealthCheckWorker returns a new instance of a HealthChecker.
-func NewHealthCheckWorker(u Upstream, interval time.Duration, client *http.Client) *HealthChecker {
- return &HealthChecker{
- upstream: u,
- Ticker: time.NewTicker(interval),
- HTTPClient: client,
- }
-}
diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go
new file mode 100644
index 0000000..abe0f9c
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/healthchecks.go
@@ -0,0 +1,270 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reverseproxy
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "net/url"
+ "regexp"
+ "strconv"
+ "time"
+
+ "github.com/caddyserver/caddy/v2"
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp"
+)
+
+// HealthChecks holds configuration related to health checking.
+type HealthChecks struct {
+ Active *ActiveHealthChecks `json:"active,omitempty"`
+ Passive *PassiveHealthChecks `json:"passive,omitempty"`
+}
+
+// ActiveHealthChecks holds configuration related to active
+// health checks (that is, health checks which occur in a
+// background goroutine independently).
+type ActiveHealthChecks struct {
+ Path string `json:"path,omitempty"`
+ Port int `json:"port,omitempty"`
+ Interval caddy.Duration `json:"interval,omitempty"`
+ Timeout caddy.Duration `json:"timeout,omitempty"`
+ MaxSize int64 `json:"max_size,omitempty"`
+ ExpectStatus int `json:"expect_status,omitempty"`
+ ExpectBody string `json:"expect_body,omitempty"`
+
+ stopChan chan struct{}
+ httpClient *http.Client
+ bodyRegexp *regexp.Regexp
+}
+
+// PassiveHealthChecks holds configuration related to passive
+// health checks (that is, health checks which occur during
+// the normal flow of request proxying).
+type PassiveHealthChecks struct {
+ MaxFails int `json:"max_fails,omitempty"`
+ FailDuration caddy.Duration `json:"fail_duration,omitempty"`
+ UnhealthyRequestCount int `json:"unhealthy_request_count,omitempty"`
+ UnhealthyStatus []int `json:"unhealthy_status,omitempty"`
+ UnhealthyLatency caddy.Duration `json:"unhealthy_latency,omitempty"`
+}
+
+// CircuitBreaker is a type that can act as an early-warning
+// system for the health checker when backends are getting
+// overloaded.
+type CircuitBreaker interface {
+ OK() bool
+ RecordMetric(statusCode int, latency time.Duration)
+}
+
+// activeHealthChecker runs active health checks on a
+// regular basis and blocks until
+// h.HealthChecks.Active.stopChan is closed.
+func (h *Handler) activeHealthChecker() {
+ ticker := time.NewTicker(time.Duration(h.HealthChecks.Active.Interval))
+ h.doActiveHealthChecksForAllHosts()
+ for {
+ select {
+ case <-ticker.C:
+ h.doActiveHealthChecksForAllHosts()
+ case <-h.HealthChecks.Active.stopChan:
+ ticker.Stop()
+ return
+ }
+ }
+}
+
+// doActiveHealthChecksForAllHosts immediately performs a
+// health checks for all hosts in the global repository.
+func (h *Handler) doActiveHealthChecksForAllHosts() {
+ hosts.Range(func(key, value interface{}) bool {
+ networkAddr := key.(string)
+ host := value.(Host)
+
+ go func(networkAddr string, host Host) {
+ network, addrs, err := caddy.ParseNetworkAddress(networkAddr)
+ if err != nil {
+ log.Printf("[ERROR] reverse_proxy: active health check for host %s: bad network address: %v", networkAddr, err)
+ return
+ }
+ if len(addrs) != 1 {
+ log.Printf("[ERROR] reverse_proxy: active health check for host %s: multiple addresses (upstream must map to only one address)", networkAddr)
+ return
+ }
+ hostAddr := addrs[0]
+ if network == "unix" || network == "unixgram" || network == "unixpacket" {
+ // this will be used as the Host portion of a http.Request URL, and
+ // paths to socket files would produce an error when creating URL,
+ // so use a fake Host value instead
+ hostAddr = network
+ }
+ err = h.doActiveHealthCheck(DialInfo{network, addrs[0]}, hostAddr, host)
+ if err != nil {
+ log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", networkAddr, err)
+ }
+ }(networkAddr, host)
+
+ // continue to iterate all hosts
+ return true
+ })
+}
+
+// doActiveHealthCheck performs a health check to host which
+// can be reached at address hostAddr. The actual address for
+// the request will be built according to active health checker
+// config. The health status of the host will be updated
+// according to whether it passes the health check. An error is
+// returned only if the health check fails to occur or if marking
+// the host's health status fails.
+func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host Host) error {
+ // create the URL for the request that acts as a health check
+ scheme := "http"
+ if ht, ok := h.Transport.(*http.Transport); ok && ht.TLSClientConfig != nil {
+ // this is kind of a hacky way to know if we should use HTTPS, but whatever
+ scheme = "https"
+ }
+ u := &url.URL{
+ Scheme: scheme,
+ Host: hostAddr,
+ Path: h.HealthChecks.Active.Path,
+ }
+
+ // adjust the port, if configured to be different
+ if h.HealthChecks.Active.Port != 0 {
+ portStr := strconv.Itoa(h.HealthChecks.Active.Port)
+ host, _, err := net.SplitHostPort(hostAddr)
+ if err != nil {
+ host = hostAddr
+ }
+ u.Host = net.JoinHostPort(host, portStr)
+ }
+
+ // attach dialing information to this request
+ ctx := context.Background()
+ ctx = context.WithValue(ctx, caddy.ReplacerCtxKey, caddy.NewReplacer())
+ ctx = context.WithValue(ctx, DialInfoCtxKey, dialInfo)
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
+ if err != nil {
+ return fmt.Errorf("making request: %v", err)
+ }
+
+ // do the request, being careful to tame the response body
+ resp, err := h.HealthChecks.Active.httpClient.Do(req)
+ if err != nil {
+ log.Printf("[INFO] reverse_proxy: active health check: %s is down (HTTP request failed: %v)", hostAddr, err)
+ _, err2 := host.SetHealthy(false)
+ if err2 != nil {
+ return fmt.Errorf("marking unhealthy: %v", err2)
+ }
+ return nil
+ }
+ var body io.Reader = resp.Body
+ if h.HealthChecks.Active.MaxSize > 0 {
+ body = io.LimitReader(body, h.HealthChecks.Active.MaxSize)
+ }
+ defer func() {
+ // drain any remaining body so connection could be re-used
+ io.Copy(ioutil.Discard, body)
+ resp.Body.Close()
+ }()
+
+ // if status code is outside criteria, mark down
+ if h.HealthChecks.Active.ExpectStatus > 0 {
+ if !caddyhttp.StatusCodeMatches(resp.StatusCode, h.HealthChecks.Active.ExpectStatus) {
+ log.Printf("[INFO] reverse_proxy: active health check: %s is down (status code %d unexpected)", hostAddr, resp.StatusCode)
+ _, err := host.SetHealthy(false)
+ if err != nil {
+ return fmt.Errorf("marking unhealthy: %v", err)
+ }
+ return nil
+ }
+ } else if resp.StatusCode < 200 || resp.StatusCode >= 400 {
+ log.Printf("[INFO] reverse_proxy: active health check: %s is down (status code %d out of tolerances)", hostAddr, resp.StatusCode)
+ _, err := host.SetHealthy(false)
+ if err != nil {
+ return fmt.Errorf("marking unhealthy: %v", err)
+ }
+ return nil
+ }
+
+ // if body does not match regex, mark down
+ if h.HealthChecks.Active.bodyRegexp != nil {
+ bodyBytes, err := ioutil.ReadAll(body)
+ if err != nil {
+ log.Printf("[INFO] reverse_proxy: active health check: %s is down (failed to read response body)", hostAddr)
+ _, err := host.SetHealthy(false)
+ if err != nil {
+ return fmt.Errorf("marking unhealthy: %v", err)
+ }
+ return nil
+ }
+ if !h.HealthChecks.Active.bodyRegexp.Match(bodyBytes) {
+ log.Printf("[INFO] reverse_proxy: active health check: %s is down (response body failed expectations)", hostAddr)
+ _, err := host.SetHealthy(false)
+ if err != nil {
+ return fmt.Errorf("marking unhealthy: %v", err)
+ }
+ return nil
+ }
+ }
+
+ // passed health check parameters, so mark as healthy
+ swapped, err := host.SetHealthy(true)
+ if swapped {
+ log.Printf("[INFO] reverse_proxy: active health check: %s is back up", hostAddr)
+ }
+ if err != nil {
+ return fmt.Errorf("marking healthy: %v", err)
+ }
+
+ return nil
+}
+
+// countFailure is used with passive health checks. It
+// remembers 1 failure for upstream for the configured
+// duration. If passive health checks are disabled or
+// failure expiry is 0, this is a no-op.
+func (h *Handler) countFailure(upstream *Upstream) {
+ // only count failures if passive health checking is enabled
+ // and if failures are configured have a non-zero expiry
+ if h.HealthChecks == nil || h.HealthChecks.Passive == nil {
+ return
+ }
+ failDuration := time.Duration(h.HealthChecks.Passive.FailDuration)
+ if failDuration == 0 {
+ return
+ }
+
+ // count failure immediately
+ err := upstream.Host.CountFail(1)
+ if err != nil {
+ log.Printf("[ERROR] proxy: upstream %s: counting failure: %v",
+ upstream.dialInfo, err)
+ }
+
+ // forget it later
+ go func(host Host, failDuration time.Duration) {
+ time.Sleep(failDuration)
+ err := host.CountFail(-1)
+ if err != nil {
+ log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v",
+ upstream.dialInfo, err)
+ }
+ }(upstream.Host, failDuration)
+}
diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go
new file mode 100644
index 0000000..1c0fae3
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/hosts.go
@@ -0,0 +1,193 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reverseproxy
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "github.com/caddyserver/caddy/v2"
+)
+
+// Host represents a remote host which can be proxied to.
+// Its methods must be safe for concurrent use.
+type Host interface {
+ // NumRequests returns the numnber of requests
+ // currently in process with the host.
+ NumRequests() int
+
+ // Fails returns the count of recent failures.
+ Fails() int
+
+ // Unhealthy returns true if the backend is unhealthy.
+ Unhealthy() bool
+
+ // CountRequest atomically counts the given number of
+ // requests as currently in process with the host. The
+ // count should not go below 0.
+ CountRequest(int) error
+
+ // CountFail atomically counts the given number of
+ // failures with the host. The count should not go
+ // below 0.
+ CountFail(int) error
+
+ // SetHealthy atomically marks the host as either
+ // healthy (true) or unhealthy (false). If the given
+ // status is the same, this should be a no-op and
+ // return false. It returns true if the status was
+ // changed; i.e. if it is now different from before.
+ SetHealthy(bool) (bool, error)
+}
+
+// UpstreamPool is a collection of upstreams.
+type UpstreamPool []*Upstream
+
+// Upstream bridges this proxy's configuration to the
+// state of the backend host it is correlated with.
+type Upstream struct {
+ Host `json:"-"`
+
+ Dial string `json:"dial,omitempty"`
+ MaxRequests int `json:"max_requests,omitempty"`
+
+ // TODO: This could be really useful, to bind requests
+ // with certain properties to specific backends
+ // HeaderAffinity string
+ // IPAffinity string
+
+ healthCheckPolicy *PassiveHealthChecks
+ cb CircuitBreaker
+ dialInfo DialInfo
+}
+
+// Available returns true if the remote host
+// is available to receive requests. This is
+// the method that should be used by selection
+// policies, etc. to determine if a backend
+// should be able to be sent a request.
+func (u *Upstream) Available() bool {
+ return u.Healthy() && !u.Full()
+}
+
+// Healthy returns true if the remote host
+// is currently known to be healthy or "up".
+// It consults the circuit breaker, if any.
+func (u *Upstream) Healthy() bool {
+ healthy := !u.Host.Unhealthy()
+ if healthy && u.healthCheckPolicy != nil {
+ healthy = u.Host.Fails() < u.healthCheckPolicy.MaxFails
+ }
+ if healthy && u.cb != nil {
+ healthy = u.cb.OK()
+ }
+ return healthy
+}
+
+// Full returns true if the remote host
+// cannot receive more requests at this time.
+func (u *Upstream) Full() bool {
+ return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests
+}
+
+// upstreamHost is the basic, in-memory representation
+// of the state of a remote host. It implements the
+// Host interface.
+type upstreamHost struct {
+ numRequests int64 // must be first field to be 64-bit aligned on 32-bit systems (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
+ fails int64
+ unhealthy int32
+}
+
+// NumRequests returns the number of active requests to the upstream.
+func (uh *upstreamHost) NumRequests() int {
+ return int(atomic.LoadInt64(&uh.numRequests))
+}
+
+// Fails returns the number of recent failures with the upstream.
+func (uh *upstreamHost) Fails() int {
+ return int(atomic.LoadInt64(&uh.fails))
+}
+
+// Unhealthy returns whether the upstream is healthy.
+func (uh *upstreamHost) Unhealthy() bool {
+ return atomic.LoadInt32(&uh.unhealthy) == 1
+}
+
+// CountRequest mutates the active request count by
+// delta. It returns an error if the adjustment fails.
+func (uh *upstreamHost) CountRequest(delta int) error {
+ result := atomic.AddInt64(&uh.numRequests, int64(delta))
+ if result < 0 {
+ return fmt.Errorf("count below 0: %d", result)
+ }
+ return nil
+}
+
+// CountFail mutates the recent failures count by
+// delta. It returns an error if the adjustment fails.
+func (uh *upstreamHost) CountFail(delta int) error {
+ result := atomic.AddInt64(&uh.fails, int64(delta))
+ if result < 0 {
+ return fmt.Errorf("count below 0: %d", result)
+ }
+ return nil
+}
+
+// SetHealthy sets the upstream has healthy or unhealthy
+// and returns true if the value was different from before,
+// or an error if the adjustment failed.
+func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) {
+ var unhealthy, compare int32 = 1, 0
+ if healthy {
+ unhealthy, compare = 0, 1
+ }
+ swapped := atomic.CompareAndSwapInt32(&uh.unhealthy, compare, unhealthy)
+ return swapped, nil
+}
+
+// DialInfo contains information needed to dial a
+// connection to an upstream host. This information
+// may be different than that which is represented
+// in a URL (for example, unix sockets don't have
+// a host that can be represented in a URL, but
+// they certainly have a network name and address).
+type DialInfo struct {
+ // The network to use. This should be one of the
+ // values that is accepted by net.Dial:
+ // https://golang.org/pkg/net/#Dial
+ Network string
+
+ // The address to dial. Follows the same
+ // semantics and rules as net.Dial.
+ Address string
+}
+
+// String returns the Caddy network address form
+// by joining the network and address with a
+// forward slash.
+func (di DialInfo) String() string {
+ return di.Network + "/" + di.Address
+}
+
+// DialInfoCtxKey is used to store a DialInfo
+// in a context.Context.
+const DialInfoCtxKey = caddy.CtxKey("dial_info")
+
+// hosts is the global repository for hosts that are
+// currently in use by active configuration(s). This
+// allows the state of remote hosts to be preserved
+// through config reloads.
+var hosts = caddy.NewUsagePool()
diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go
new file mode 100644
index 0000000..c135ac8
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/httptransport.go
@@ -0,0 +1,208 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reverseproxy
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
+ "fmt"
+ "net"
+ "net/http"
+ "reflect"
+ "time"
+
+ "github.com/caddyserver/caddy/v2"
+)
+
+func init() {
+ caddy.RegisterModule(HTTPTransport{})
+}
+
+// HTTPTransport is essentially a configuration wrapper for http.Transport.
+// It defines a JSON structure useful when configuring the HTTP transport
+// for Caddy's reverse proxy.
+type HTTPTransport struct {
+ // TODO: It's possible that other transports (like fastcgi) might be
+ // able to borrow/use at least some of these config fields; if so,
+ // move them into a type called CommonTransport and embed it
+ TLS *TLSConfig `json:"tls,omitempty"`
+ KeepAlive *KeepAlive `json:"keep_alive,omitempty"`
+ Compression *bool `json:"compression,omitempty"`
+ MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // TODO: NOTE: we use our health check stuff to enforce max REQUESTS per host, but this is connections
+ DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
+ FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
+ ResponseHeaderTimeout caddy.Duration `json:"response_header_timeout,omitempty"`
+ ExpectContinueTimeout caddy.Duration `json:"expect_continue_timeout,omitempty"`
+ MaxResponseHeaderSize int64 `json:"max_response_header_size,omitempty"`
+ WriteBufferSize int `json:"write_buffer_size,omitempty"`
+ ReadBufferSize int `json:"read_buffer_size,omitempty"`
+
+ RoundTripper http.RoundTripper `json:"-"`
+}
+
+// CaddyModule returns the Caddy module information.
+func (HTTPTransport) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.transport.http",
+ New: func() caddy.Module { return new(HTTPTransport) },
+ }
+}
+
+// Provision sets up h.RoundTripper with a http.Transport
+// that is ready to use.
+func (h *HTTPTransport) Provision(_ caddy.Context) error {
+ dialer := &net.Dialer{
+ Timeout: time.Duration(h.DialTimeout),
+ FallbackDelay: time.Duration(h.FallbackDelay),
+ // TODO: Resolver
+ }
+
+ rt := &http.Transport{
+ DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
+ // the proper dialing information should be embedded into the request's context
+ if dialInfoVal := ctx.Value(DialInfoCtxKey); dialInfoVal != nil {
+ dialInfo := dialInfoVal.(DialInfo)
+ network = dialInfo.Network
+ address = dialInfo.Address
+ }
+ return dialer.DialContext(ctx, network, address)
+ },
+ MaxConnsPerHost: h.MaxConnsPerHost,
+ ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
+ ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout),
+ MaxResponseHeaderBytes: h.MaxResponseHeaderSize,
+ WriteBufferSize: h.WriteBufferSize,
+ ReadBufferSize: h.ReadBufferSize,
+ }
+
+ if h.TLS != nil {
+ rt.TLSHandshakeTimeout = time.Duration(h.TLS.HandshakeTimeout)
+
+ var err error
+ rt.TLSClientConfig, err = h.TLS.MakeTLSClientConfig()
+ if err != nil {
+ return fmt.Errorf("making TLS client config: %v", err)
+ }
+ }
+
+ if h.KeepAlive != nil {
+ dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval)
+ if enabled := h.KeepAlive.Enabled; enabled != nil {
+ rt.DisableKeepAlives = !*enabled
+ }
+ rt.MaxIdleConns = h.KeepAlive.MaxIdleConns
+ rt.MaxIdleConnsPerHost = h.KeepAlive.MaxIdleConnsPerHost
+ rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout)
+ }
+
+ if h.Compression != nil {
+ rt.DisableCompression = !*h.Compression
+ }
+
+ h.RoundTripper = rt
+
+ return nil
+}
+
+// RoundTrip implements http.RoundTripper with h.RoundTripper.
+func (h HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ return h.RoundTripper.RoundTrip(req)
+}
+
+// TLSConfig holds configuration related to the
+// TLS configuration for the transport/client.
+type TLSConfig struct {
+ RootCAPool []string `json:"root_ca_pool,omitempty"`
+ // TODO: Should the client cert+key config use caddytls.CertificateLoader modules?
+ ClientCertificateFile string `json:"client_certificate_file,omitempty"`
+ ClientCertificateKeyFile string `json:"client_certificate_key_file,omitempty"`
+ InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"`
+ HandshakeTimeout caddy.Duration `json:"handshake_timeout,omitempty"`
+}
+
+// MakeTLSClientConfig returns a tls.Config usable by a client to a backend.
+// If there is no custom TLS configuration, a nil config may be returned.
+func (t TLSConfig) MakeTLSClientConfig() (*tls.Config, error) {
+ cfg := new(tls.Config)
+
+ // client auth
+ if t.ClientCertificateFile != "" && t.ClientCertificateKeyFile == "" {
+ return nil, fmt.Errorf("client_certificate_file specified without client_certificate_key_file")
+ }
+ if t.ClientCertificateFile == "" && t.ClientCertificateKeyFile != "" {
+ return nil, fmt.Errorf("client_certificate_key_file specified without client_certificate_file")
+ }
+ if t.ClientCertificateFile != "" && t.ClientCertificateKeyFile != "" {
+ cert, err := tls.LoadX509KeyPair(t.ClientCertificateFile, t.ClientCertificateKeyFile)
+ if err != nil {
+ return nil, fmt.Errorf("loading client certificate key pair: %v", err)
+ }
+ cfg.Certificates = []tls.Certificate{cert}
+ }
+
+ // trusted root CAs
+ if len(t.RootCAPool) > 0 {
+ rootPool := x509.NewCertPool()
+ for _, encodedCACert := range t.RootCAPool {
+ caCert, err := decodeBase64DERCert(encodedCACert)
+ if err != nil {
+ return nil, fmt.Errorf("parsing CA certificate: %v", err)
+ }
+ rootPool.AddCert(caCert)
+ }
+ cfg.RootCAs = rootPool
+ }
+
+ // throw all security out the window
+ cfg.InsecureSkipVerify = t.InsecureSkipVerify
+
+ // only return a config if it's not empty
+ if reflect.DeepEqual(cfg, new(tls.Config)) {
+ return nil, nil
+ }
+
+ cfg.NextProtos = []string{"h2", "http/1.1"} // TODO: ensure that this actually enables HTTP/2
+
+ return cfg, nil
+}
+
+// decodeBase64DERCert base64-decodes, then DER-decodes, certStr.
+func decodeBase64DERCert(certStr string) (*x509.Certificate, error) {
+ // decode base64
+ derBytes, err := base64.StdEncoding.DecodeString(certStr)
+ if err != nil {
+ return nil, err
+ }
+
+ // parse the DER-encoded certificate
+ return x509.ParseCertificate(derBytes)
+}
+
+// KeepAlive holds configuration pertaining to HTTP Keep-Alive.
+type KeepAlive struct {
+ Enabled *bool `json:"enabled,omitempty"`
+ ProbeInterval caddy.Duration `json:"probe_interval,omitempty"`
+ MaxIdleConns int `json:"max_idle_conns,omitempty"`
+ MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"`
+ IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle
+}
+
+// Interface guards
+var (
+ _ caddy.Provisioner = (*HTTPTransport)(nil)
+ _ http.RoundTripper = (*HTTPTransport)(nil)
+)
diff --git a/modules/caddyhttp/reverseproxy/module.go b/modules/caddyhttp/reverseproxy/module.go
deleted file mode 100755
index 21aca1d..0000000
--- a/modules/caddyhttp/reverseproxy/module.go
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright 2015 Matthew Holt and The Caddy Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package reverseproxy
-
-import (
- "github.com/caddyserver/caddy/v2"
- "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
- "github.com/caddyserver/caddy/v2/modules/caddyhttp"
-)
-
-func init() {
- caddy.RegisterModule(new(LoadBalanced))
- httpcaddyfile.RegisterHandlerDirective("reverse_proxy", parseCaddyfile) // TODO: "proxy"?
-}
-
-// CaddyModule returns the Caddy module information.
-func (*LoadBalanced) CaddyModule() caddy.ModuleInfo {
- return caddy.ModuleInfo{
- Name: "http.handlers.reverse_proxy",
- New: func() caddy.Module { return new(LoadBalanced) },
- }
-}
-
-// parseCaddyfile sets up the handler from Caddyfile tokens. Syntax:
-//
-// proxy [<matcher>] <to>
-//
-// TODO: This needs to be finished. It definitely needs to be able to open a block...
-func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
- lb := new(LoadBalanced)
- for h.Next() {
- allTo := h.RemainingArgs()
- if len(allTo) == 0 {
- return nil, h.ArgErr()
- }
- for _, to := range allTo {
- lb.Upstreams = append(lb.Upstreams, &UpstreamConfig{Host: to})
- }
- }
- return lb, nil
-}
diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go
index 68393de..5a37613 100755..100644
--- a/modules/caddyhttp/reverseproxy/reverseproxy.go
+++ b/modules/caddyhttp/reverseproxy/reverseproxy.go
@@ -16,227 +16,304 @@ package reverseproxy
import (
"context"
+ "encoding/json"
"fmt"
- "io"
- "log"
"net"
"net/http"
- "net/url"
+ "regexp"
"strings"
- "sync"
"time"
+ "github.com/caddyserver/caddy/v2"
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp"
"golang.org/x/net/http/httpguts"
)
-// ReverseProxy is an HTTP Handler that takes an incoming request and
-// sends it to another server, proxying the response back to the
-// client.
-type ReverseProxy struct {
- // Director must be a function which modifies
- // the request into a new request to be sent
- // using Transport. Its response is then copied
- // back to the original client unmodified.
- // Director must not access the provided Request
- // after returning.
- Director func(*http.Request)
-
- // The transport used to perform proxy requests.
- // If nil, http.DefaultTransport is used.
- Transport http.RoundTripper
-
- // FlushInterval specifies the flush interval
- // to flush to the client while copying the
- // response body.
- // If zero, no periodic flushing is done.
- // A negative value means to flush immediately
- // after each write to the client.
- // The FlushInterval is ignored when ReverseProxy
- // recognizes a response as a streaming response;
- // for such responses, writes are flushed to the client
- // immediately.
- FlushInterval time.Duration
-
- // ErrorLog specifies an optional logger for errors
- // that occur when attempting to proxy the request.
- // If nil, logging goes to os.Stderr via the log package's
- // standard logger.
- ErrorLog *log.Logger
-
- // BufferPool optionally specifies a buffer pool to
- // get byte slices for use by io.CopyBuffer when
- // copying HTTP response bodies.
- BufferPool BufferPool
-
- // ModifyResponse is an optional function that modifies the
- // Response from the backend. It is called if the backend
- // returns a response at all, with any HTTP status code.
- // If the backend is unreachable, the optional ErrorHandler is
- // called without any call to ModifyResponse.
- //
- // If ModifyResponse returns an error, ErrorHandler is called
- // with its error value. If ErrorHandler is nil, its default
- // implementation is used.
- ModifyResponse func(*http.Response) error
-
- // ErrorHandler is an optional function that handles errors
- // reaching the backend or errors from ModifyResponse.
- //
- // If nil, the default is to log the provided error and return
- // a 502 Status Bad Gateway response.
- ErrorHandler func(http.ResponseWriter, *http.Request, error)
+func init() {
+ caddy.RegisterModule(Handler{})
}
-// A BufferPool is an interface for getting and returning temporary
-// byte slices for use by io.CopyBuffer.
-type BufferPool interface {
- Get() []byte
- Put([]byte)
+// Handler implements a highly configurable and production-ready reverse proxy.
+type Handler struct {
+ TransportRaw json.RawMessage `json:"transport,omitempty"`
+ CBRaw json.RawMessage `json:"circuit_breaker,omitempty"`
+ LoadBalancing *LoadBalancing `json:"load_balancing,omitempty"`
+ HealthChecks *HealthChecks `json:"health_checks,omitempty"`
+ Upstreams UpstreamPool `json:"upstreams,omitempty"`
+ FlushInterval caddy.Duration `json:"flush_interval,omitempty"`
+
+ Transport http.RoundTripper `json:"-"`
+ CB CircuitBreaker `json:"-"`
}
-func singleJoiningSlash(a, b string) string {
- aslash := strings.HasSuffix(a, "/")
- bslash := strings.HasPrefix(b, "/")
- switch {
- case aslash && bslash:
- return a + b[1:]
- case !aslash && !bslash:
- return a + "/" + b
+// CaddyModule returns the Caddy module information.
+func (Handler) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy",
+ New: func() caddy.Module { return new(Handler) },
}
- return a + b
}
-// NewSingleHostReverseProxy returns a new ReverseProxy that routes
-// URLs to the scheme, host, and base path provided in target. If the
-// target's path is "/base" and the incoming request was for "/dir",
-// the target request will be for /base/dir.
-// NewSingleHostReverseProxy does not rewrite the Host header.
-// To rewrite Host headers, use ReverseProxy directly with a custom
-// Director policy.
-func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
- targetQuery := target.RawQuery
- director := func(req *http.Request) {
- req.URL.Scheme = target.Scheme
- req.URL.Host = target.Host
- req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
- if targetQuery == "" || req.URL.RawQuery == "" {
- req.URL.RawQuery = targetQuery + req.URL.RawQuery
- } else {
- req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
+// Provision ensures that h is set up properly before use.
+func (h *Handler) Provision(ctx caddy.Context) error {
+ // start by loading modules
+ if h.TransportRaw != nil {
+ val, err := ctx.LoadModuleInline("protocol", "http.handlers.reverse_proxy.transport", h.TransportRaw)
+ if err != nil {
+ return fmt.Errorf("loading transport module: %s", err)
}
- if _, ok := req.Header["User-Agent"]; !ok {
- // explicitly disable User-Agent so it's not set to default value
- req.Header.Set("User-Agent", "")
+ h.Transport = val.(http.RoundTripper)
+ h.TransportRaw = nil // allow GC to deallocate - TODO: Does this help?
+ }
+ if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil {
+ val, err := ctx.LoadModuleInline("policy",
+ "http.handlers.reverse_proxy.selection_policies",
+ h.LoadBalancing.SelectionPolicyRaw)
+ if err != nil {
+ return fmt.Errorf("loading load balancing selection module: %s", err)
}
+ h.LoadBalancing.SelectionPolicy = val.(Selector)
+ h.LoadBalancing.SelectionPolicyRaw = nil // allow GC to deallocate - TODO: Does this help?
+ }
+ if h.CBRaw != nil {
+ val, err := ctx.LoadModuleInline("type", "http.handlers.reverse_proxy.circuit_breakers", h.CBRaw)
+ if err != nil {
+ return fmt.Errorf("loading circuit breaker module: %s", err)
+ }
+ h.CB = val.(CircuitBreaker)
+ h.CBRaw = nil // allow GC to deallocate - TODO: Does this help?
}
- return &ReverseProxy{Director: director}
-}
-func copyHeader(dst, src http.Header) {
- for k, vv := range src {
- for _, v := range vv {
- dst.Add(k, v)
+ if h.Transport == nil {
+ t := &HTTPTransport{
+ KeepAlive: &KeepAlive{
+ ProbeInterval: caddy.Duration(30 * time.Second),
+ IdleConnTimeout: caddy.Duration(2 * time.Minute),
+ },
+ DialTimeout: caddy.Duration(10 * time.Second),
}
+ err := t.Provision(ctx)
+ if err != nil {
+ return fmt.Errorf("provisioning default transport: %v", err)
+ }
+ h.Transport = t
}
-}
-func cloneHeader(h http.Header) http.Header {
- h2 := make(http.Header, len(h))
- for k, vv := range h {
- vv2 := make([]string, len(vv))
- copy(vv2, vv)
- h2[k] = vv2
+ if h.LoadBalancing == nil {
+ h.LoadBalancing = new(LoadBalancing)
+ }
+ if h.LoadBalancing.SelectionPolicy == nil {
+ h.LoadBalancing.SelectionPolicy = RandomSelection{}
+ }
+ if h.LoadBalancing.TryDuration > 0 && h.LoadBalancing.TryInterval == 0 {
+ // a non-zero try_duration with a zero try_interval
+ // will always spin the CPU for try_duration if the
+ // upstream is local or low-latency; avoid that by
+ // defaulting to a sane wait period between attempts
+ h.LoadBalancing.TryInterval = caddy.Duration(250 * time.Millisecond)
}
- return h2
-}
-// Hop-by-hop headers. These are removed when sent to the backend.
-// As of RFC 7230, hop-by-hop headers are required to appear in the
-// Connection header field. These are the headers defined by the
-// obsoleted RFC 2616 (section 13.5.1) and are used for backward
-// compatibility.
-var hopHeaders = []string{
- "Connection",
- "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
- "Keep-Alive",
- "Proxy-Authenticate",
- "Proxy-Authorization",
- "Te", // canonicalized version of "TE"
- "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
- "Transfer-Encoding",
- "Upgrade",
-}
+ // if active health checks are enabled, configure them and start a worker
+ if h.HealthChecks != nil &&
+ h.HealthChecks.Active != nil &&
+ (h.HealthChecks.Active.Path != "" || h.HealthChecks.Active.Port != 0) {
+ timeout := time.Duration(h.HealthChecks.Active.Timeout)
+ if timeout == 0 {
+ timeout = 10 * time.Second
+ }
-func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
- p.logf("http: proxy error: %v", err)
- rw.WriteHeader(http.StatusBadGateway)
-}
+ h.HealthChecks.Active.stopChan = make(chan struct{})
+ h.HealthChecks.Active.httpClient = &http.Client{
+ Timeout: timeout,
+ Transport: h.Transport,
+ }
+
+ if h.HealthChecks.Active.Interval == 0 {
+ h.HealthChecks.Active.Interval = caddy.Duration(30 * time.Second)
+ }
+
+ if h.HealthChecks.Active.ExpectBody != "" {
+ var err error
+ h.HealthChecks.Active.bodyRegexp, err = regexp.Compile(h.HealthChecks.Active.ExpectBody)
+ if err != nil {
+ return fmt.Errorf("expect_body: compiling regular expression: %v", err)
+ }
+ }
+
+ go h.activeHealthChecker()
+ }
-func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
- if p.ErrorHandler != nil {
- return p.ErrorHandler
+ var allUpstreams []*Upstream
+ for _, upstream := range h.Upstreams {
+ // upstreams are allowed to map to only a single host,
+ // but an upstream's address may semantically represent
+ // multiple addresses, so make sure to handle each
+ // one in turn based on this one upstream config
+ network, addresses, err := caddy.ParseNetworkAddress(upstream.Dial)
+ if err != nil {
+ return fmt.Errorf("parsing dial address: %v", err)
+ }
+
+ for _, addr := range addresses {
+ // make a new upstream based on the original
+ // that has a singular dial address
+ upstreamCopy := *upstream
+ upstreamCopy.dialInfo = DialInfo{network, addr}
+ upstreamCopy.Dial = upstreamCopy.dialInfo.String()
+ upstreamCopy.cb = h.CB
+
+ // if host already exists from a current config,
+ // use that instead; otherwise, add it
+ // TODO: make hosts modular, so that their state can be distributed in enterprise for example
+ // TODO: If distributed, the pool should be stored in storage...
+ var host Host = new(upstreamHost)
+ activeHost, loaded := hosts.LoadOrStore(upstreamCopy.Dial, host)
+ if loaded {
+ host = activeHost.(Host)
+ }
+ upstreamCopy.Host = host
+
+ // if the passive health checker has a non-zero "unhealthy
+ // request count" but the upstream has no MaxRequests set
+ // (they are the same thing, but one is a default value for
+ // for upstreams with a zero MaxRequests), copy the default
+ // value into this upstream, since the value in the upstream
+ // (MaxRequests) is what is used during availability checks
+ if h.HealthChecks != nil &&
+ h.HealthChecks.Passive != nil &&
+ h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
+ upstreamCopy.MaxRequests == 0 {
+ upstreamCopy.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
+ }
+
+ // upstreams need independent access to the passive
+ // health check policy because they run outside of the
+ // scope of a request handler
+ if h.HealthChecks != nil {
+ upstreamCopy.healthCheckPolicy = h.HealthChecks.Passive
+ }
+
+ allUpstreams = append(allUpstreams, &upstreamCopy)
+ }
}
- return p.defaultErrorHandler
+
+ // replace the unmarshaled upstreams (possible 1:many
+ // address mapping) with our list, which is mapped 1:1,
+ // thus may have expanded the original list
+ h.Upstreams = allUpstreams
+
+ return nil
}
-// modifyResponse conditionally runs the optional ModifyResponse hook
-// and reports whether the request should proceed.
-func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
- if p.ModifyResponse == nil {
- return true
+// Cleanup cleans up the resources made by h during provisioning.
+func (h *Handler) Cleanup() error {
+ // stop the active health checker
+ if h.HealthChecks != nil &&
+ h.HealthChecks.Active != nil &&
+ h.HealthChecks.Active.stopChan != nil {
+ close(h.HealthChecks.Active.stopChan)
}
- if err := p.ModifyResponse(res); err != nil {
- res.Body.Close()
- p.getErrorHandler()(rw, req, err)
- return false
+
+ // remove hosts from our config from the pool
+ for _, upstream := range h.Upstreams {
+ hosts.Delete(upstream.dialInfo.String())
}
- return true
+
+ return nil
}
-func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*http.Response, error) {
- transport := p.Transport
- if transport == nil {
- transport = http.DefaultTransport
- }
-
- ctx := req.Context()
- if cn, ok := rw.(http.CloseNotifier); ok {
- var cancel context.CancelFunc
- ctx, cancel = context.WithCancel(ctx)
- defer cancel()
- notifyChan := cn.CloseNotify()
- go func() {
- select {
- case <-notifyChan:
- cancel()
- case <-ctx.Done():
+func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
+ // prepare the request for proxying; this is needed only once
+ err := h.prepareRequest(r)
+ if err != nil {
+ return caddyhttp.Error(http.StatusInternalServerError,
+ fmt.Errorf("preparing request for upstream round-trip: %v", err))
+ }
+
+ start := time.Now()
+
+ var proxyErr error
+ for {
+ // choose an available upstream
+ upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, r)
+ if upstream == nil {
+ if proxyErr == nil {
+ proxyErr = fmt.Errorf("no available upstreams")
+ }
+ if !h.tryAgain(start, proxyErr) {
+ break
}
- }()
+ continue
+ }
+
+ // attach to the request information about how to dial the upstream;
+ // this is necessary because the information cannot be sufficiently
+ // or satisfactorily represented in a URL
+ ctx := context.WithValue(r.Context(), DialInfoCtxKey, upstream.dialInfo)
+ r = r.WithContext(ctx)
+
+ // proxy the request to that upstream
+ proxyErr = h.reverseProxy(w, r, upstream)
+ if proxyErr == nil || proxyErr == context.Canceled {
+ // context.Canceled happens when the downstream client
+ // cancels the request; we don't have to worry about that
+ return nil
+ }
+
+ // remember this failure (if enabled)
+ h.countFailure(upstream)
+
+ // if we've tried long enough, break
+ if !h.tryAgain(start, proxyErr) {
+ break
+ }
+ }
+
+ return caddyhttp.Error(http.StatusBadGateway, proxyErr)
+}
+
+// prepareRequest modifies req so that it is ready to be proxied,
+// except for directing to a specific upstream. This method mutates
+// headers and other necessary properties of the request and should
+// be done just once (before proxying) regardless of proxy retries.
+// This assumes that no mutations of the request are performed
+// by h during or after proxying.
+func (h Handler) prepareRequest(req *http.Request) error {
+ // as a special (but very common) case, if the transport
+ // is HTTP, then ensure the request has the proper scheme
+ // because incoming requests by default are lacking it
+ if req.URL.Scheme == "" {
+ req.URL.Scheme = "http"
+ if ht, ok := h.Transport.(*HTTPTransport); ok && ht.TLS != nil {
+ req.URL.Scheme = "https"
+ }
}
- outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay
if req.ContentLength == 0 {
- outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
+ req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries
}
- outreq.Header = cloneHeader(req.Header)
+ req.Close = false
- p.Director(outreq)
- outreq.Close = false
+ // if User-Agent is not set by client, then explicitly
+ // disable it so it's not set to default value by std lib
+ if _, ok := req.Header["User-Agent"]; !ok {
+ req.Header.Set("User-Agent", "")
+ }
- reqUpType := upgradeType(outreq.Header)
- removeConnectionHeaders(outreq.Header)
+ reqUpType := upgradeType(req.Header)
+ removeConnectionHeaders(req.Header)
// Remove hop-by-hop headers to the backend. Especially
// important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us.
for _, h := range hopHeaders {
- hv := outreq.Header.Get(h)
+ hv := req.Header.Get(h)
if hv == "" {
continue
}
if h == "Te" && hv == "trailers" {
- // Issue 21096: tell backend applications that
+ // Issue golang/go#21096: tell backend applications that
// care about trailer support that we support
// trailers. (We do, but we don't go out of
// our way to advertise that unless the
@@ -244,40 +321,72 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht
// worth mentioning)
continue
}
- outreq.Header.Del(h)
+ req.Header.Del(h)
}
// After stripping all the hop-by-hop connection headers above, add back any
// necessary for protocol upgrades, such as for websockets.
if reqUpType != "" {
- outreq.Header.Set("Connection", "Upgrade")
- outreq.Header.Set("Upgrade", reqUpType)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", reqUpType)
}
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
- if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
+ if prior, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
- outreq.Header.Set("X-Forwarded-For", clientIP)
+ req.Header.Set("X-Forwarded-For", clientIP)
}
- res, err := transport.RoundTrip(outreq)
+ return nil
+}
+
+// reverseProxy performs a round-trip to the given backend and processes the response with the client.
+// (This method is mostly the beginning of what was borrowed from the net/http/httputil package in the
+// Go standard library which was used as the foundation.)
+func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, upstream *Upstream) error {
+ upstream.Host.CountRequest(1)
+ defer upstream.Host.CountRequest(-1)
+
+ // point the request to this upstream
+ h.directRequest(req, upstream)
+
+ // do the round-trip
+ start := time.Now()
+ res, err := h.Transport.RoundTrip(req)
+ latency := time.Since(start)
if err != nil {
- p.getErrorHandler()(rw, outreq, err)
- return nil, err
+ return err
}
- // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
- if res.StatusCode == http.StatusSwitchingProtocols {
- if !p.modifyResponse(rw, res, outreq) {
- return res, nil
+ // update circuit breaker on current conditions
+ if upstream.cb != nil {
+ upstream.cb.RecordMetric(res.StatusCode, latency)
+ }
+
+ // perform passive health checks (if enabled)
+ if h.HealthChecks != nil && h.HealthChecks.Passive != nil {
+ // strike if the status code matches one that is "bad"
+ for _, badStatus := range h.HealthChecks.Passive.UnhealthyStatus {
+ if caddyhttp.StatusCodeMatches(res.StatusCode, badStatus) {
+ h.countFailure(upstream)
+ }
}
- p.handleUpgradeResponse(rw, outreq, res)
- return res, nil
+ // strike if the roundtrip took too long
+ if h.HealthChecks.Passive.UnhealthyLatency > 0 &&
+ latency >= time.Duration(h.HealthChecks.Passive.UnhealthyLatency) {
+ h.countFailure(upstream)
+ }
+ }
+
+ // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
+ if res.StatusCode == http.StatusSwitchingProtocols {
+ h.handleUpgradeResponse(rw, req, res)
+ return nil
}
removeConnectionHeaders(res.Header)
@@ -286,10 +395,6 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht
res.Header.Del(h)
}
- if !p.modifyResponse(rw, res, outreq) {
- return res, nil
- }
-
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
@@ -305,15 +410,16 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht
rw.WriteHeader(res.StatusCode)
- err = p.copyResponse(rw, res.Body, p.flushInterval(req, res))
+ err = h.copyResponse(rw, res.Body, h.flushInterval(req, res))
if err != nil {
defer res.Body.Close()
// Since we're streaming the response, if we run into an error all we can do
- // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler
+ // is abort the request. Issue golang/go#23643: ReverseProxy should use ErrAbortHandler
// on read error while copying body.
+ // TODO: Look into whether we want to panic at all in our case...
if !shouldPanicOnCopyError(req) {
- p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
- return nil, err
+ // p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
+ return err
}
panic(http.ErrAbortHandler)
@@ -331,7 +437,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht
if len(res.Trailer) == announcedTrailers {
copyHeader(rw.Header(), res.Trailer)
- return res, nil
+ return nil
}
for k, vv := range res.Trailer {
@@ -341,21 +447,48 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht
}
}
- return res, nil
+ return nil
+}
+
+// tryAgain takes the time that the handler was initially invoked
+// as well as any error currently obtained and returns true if
+// another attempt should be made at proxying the request. If
+// true is returned, it has already blocked long enough before
+// the next retry (i.e. no more sleeping is needed). If false is
+// returned, the handler should stop trying to proxy the request.
+func (h Handler) tryAgain(start time.Time, proxyErr error) bool {
+ // if downstream has canceled the request, break
+ if proxyErr == context.Canceled {
+ return false
+ }
+ // if we've tried long enough, break
+ if time.Since(start) >= time.Duration(h.LoadBalancing.TryDuration) {
+ return false
+ }
+ // otherwise, wait and try the next available host
+ time.Sleep(time.Duration(h.LoadBalancing.TryInterval))
+ return true
}
-var inOurTests bool // whether we're in our own tests
+// directRequest modifies only req.URL so that it points to the
+// given upstream host. It must modify ONLY the request URL.
+func (h Handler) directRequest(req *http.Request, upstream *Upstream) {
+ if req.URL.Host == "" {
+ req.URL.Host = upstream.dialInfo.Address
+ }
+}
// shouldPanicOnCopyError reports whether the reverse proxy should
// panic with http.ErrAbortHandler. This is the right thing to do by
// default, but Go 1.10 and earlier did not, so existing unit tests
// weren't expecting panics. Only panic in our own tests, or when
// running under the HTTP server.
+// TODO: I don't know if we want this at all...
func shouldPanicOnCopyError(req *http.Request) bool {
- if inOurTests {
- // Our tests know to handle this panic.
- return true
- }
+ // if inOurTests {
+ // // Our tests know to handle this panic.
+ // return true
+ // }
if req.Context().Value(http.ServerContextKey) != nil {
// We seem to be running under an HTTP server, so
// it'll recover the panic.
@@ -366,146 +499,22 @@ func shouldPanicOnCopyError(req *http.Request) bool {
return false
}
-// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
-// See RFC 7230, section 6.1
-func removeConnectionHeaders(h http.Header) {
- if c := h.Get("Connection"); c != "" {
- for _, f := range strings.Split(c, ",") {
- if f = strings.TrimSpace(f); f != "" {
- h.Del(f)
- }
- }
- }
-}
-
-// flushInterval returns the p.FlushInterval value, conditionally
-// overriding its value for a specific request/response.
-func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration {
- resCT := res.Header.Get("Content-Type")
-
- // For Server-Sent Events responses, flush immediately.
- // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
- if resCT == "text/event-stream" {
- return -1 // negative means immediately
- }
-
- // TODO: more specific cases? e.g. res.ContentLength == -1?
- return p.FlushInterval
-}
-
-func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
- if flushInterval != 0 {
- if wf, ok := dst.(writeFlusher); ok {
- mlw := &maxLatencyWriter{
- dst: wf,
- latency: flushInterval,
- }
- defer mlw.stop()
- dst = mlw
- }
- }
-
- var buf []byte
- if p.BufferPool != nil {
- buf = p.BufferPool.Get()
- defer p.BufferPool.Put(buf)
- }
- _, err := p.copyBuffer(dst, src, buf)
- return err
-}
-
-// copyBuffer returns any write errors or non-EOF read errors, and the amount
-// of bytes written.
-func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
- if len(buf) == 0 {
- buf = make([]byte, 32*1024)
- }
- var written int64
- for {
- nr, rerr := src.Read(buf)
- if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
- p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
- }
- if nr > 0 {
- nw, werr := dst.Write(buf[:nr])
- if nw > 0 {
- written += int64(nw)
- }
- if werr != nil {
- return written, werr
- }
- if nr != nw {
- return written, io.ErrShortWrite
- }
- }
- if rerr != nil {
- if rerr == io.EOF {
- rerr = nil
- }
- return written, rerr
+func copyHeader(dst, src http.Header) {
+ for k, vv := range src {
+ for _, v := range vv {
+ dst.Add(k, v)
}
}
}
-func (p *ReverseProxy) logf(format string, args ...interface{}) {
- if p.ErrorLog != nil {
- p.ErrorLog.Printf(format, args...)
- } else {
- log.Printf(format, args...)
- }
-}
-
-type writeFlusher interface {
- io.Writer
- http.Flusher
-}
-
-type maxLatencyWriter struct {
- dst writeFlusher
- latency time.Duration // non-zero; negative means to flush immediately
-
- mu sync.Mutex // protects t, flushPending, and dst.Flush
- t *time.Timer
- flushPending bool
-}
-
-func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
- m.mu.Lock()
- defer m.mu.Unlock()
- n, err = m.dst.Write(p)
- if m.latency < 0 {
- m.dst.Flush()
- return
- }
- if m.flushPending {
- return
- }
- if m.t == nil {
- m.t = time.AfterFunc(m.latency, m.delayedFlush)
- } else {
- m.t.Reset(m.latency)
- }
- m.flushPending = true
- return
-}
-
-func (m *maxLatencyWriter) delayedFlush() {
- m.mu.Lock()
- defer m.mu.Unlock()
- if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
- return
- }
- m.dst.Flush()
- m.flushPending = false
-}
-
-func (m *maxLatencyWriter) stop() {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.flushPending = false
- if m.t != nil {
- m.t.Stop()
+func cloneHeader(h http.Header) http.Header {
+ h2 := make(http.Header, len(h))
+ for k, vv := range h {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ h2[k] = vv2
}
+ return h2
}
func upgradeType(h http.Header) string {
@@ -515,62 +524,71 @@ func upgradeType(h http.Header) string {
return strings.ToLower(h.Get("Upgrade"))
}
-func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
- reqUpType := upgradeType(req.Header)
- resUpType := upgradeType(res.Header)
- if reqUpType != resUpType {
- p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
- return
+func singleJoiningSlash(a, b string) string {
+ aslash := strings.HasSuffix(a, "/")
+ bslash := strings.HasPrefix(b, "/")
+ switch {
+ case aslash && bslash:
+ return a + b[1:]
+ case !aslash && !bslash:
+ return a + "/" + b
}
+ return a + b
+}
- copyHeader(res.Header, rw.Header())
-
- hj, ok := rw.(http.Hijacker)
- if !ok {
- p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
- return
- }
- backConn, ok := res.Body.(io.ReadWriteCloser)
- if !ok {
- p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
- return
+// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
+// See RFC 7230, section 6.1
+func removeConnectionHeaders(h http.Header) {
+ if c := h.Get("Connection"); c != "" {
+ for _, f := range strings.Split(c, ",") {
+ if f = strings.TrimSpace(f); f != "" {
+ h.Del(f)
+ }
+ }
}
- defer backConn.Close()
- conn, brw, err := hj.Hijack()
- if err != nil {
- p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
- return
- }
- defer conn.Close()
- res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
- if err := res.Write(brw); err != nil {
- p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
- return
- }
- if err := brw.Flush(); err != nil {
- p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
- return
- }
- errc := make(chan error, 1)
- spc := switchProtocolCopier{user: conn, backend: backConn}
- go spc.copyToBackend(errc)
- go spc.copyFromBackend(errc)
- <-errc
- return
}
-// switchProtocolCopier exists so goroutines proxying data back and
-// forth have nice names in stacks.
-type switchProtocolCopier struct {
- user, backend io.ReadWriter
+// LoadBalancing has parameters related to load balancing.
+type LoadBalancing struct {
+ SelectionPolicyRaw json.RawMessage `json:"selection_policy,omitempty"`
+ TryDuration caddy.Duration `json:"try_duration,omitempty"`
+ TryInterval caddy.Duration `json:"try_interval,omitempty"`
+
+ SelectionPolicy Selector `json:"-"`
}
-func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
- _, err := io.Copy(c.user, c.backend)
- errc <- err
+// Selector selects an available upstream from the pool.
+type Selector interface {
+ Select(UpstreamPool, *http.Request) *Upstream
}
-func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
- _, err := io.Copy(c.backend, c.user)
- errc <- err
+// Hop-by-hop headers. These are removed when sent to the backend.
+// As of RFC 7230, hop-by-hop headers are required to appear in the
+// Connection header field. These are the headers defined by the
+// obsoleted RFC 2616 (section 13.5.1) and are used for backward
+// compatibility.
+var hopHeaders = []string{
+ "Connection",
+ "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
+ "Keep-Alive",
+ "Proxy-Authenticate",
+ "Proxy-Authorization",
+ "Te", // canonicalized version of "TE"
+ "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
+ "Transfer-Encoding",
+ "Upgrade",
}
+
+// TODO: see if we can use this
+// var bufPool = sync.Pool{
+// New: func() interface{} {
+// return new(bytes.Buffer)
+// },
+// }
+
+// Interface guards
+var (
+ _ caddy.Provisioner = (*Handler)(nil)
+ _ caddy.CleanerUpper = (*Handler)(nil)
+ _ caddyhttp.MiddlewareHandler = (*Handler)(nil)
+)
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go
new file mode 100644
index 0000000..5bb2d62
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go
@@ -0,0 +1,353 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reverseproxy
+
+import (
+ "fmt"
+ "hash/fnv"
+ weakrand "math/rand"
+ "net"
+ "net/http"
+ "sync/atomic"
+ "time"
+
+ "github.com/caddyserver/caddy/v2"
+)
+
+func init() {
+ caddy.RegisterModule(RandomSelection{})
+ caddy.RegisterModule(RandomChoiceSelection{})
+ caddy.RegisterModule(LeastConnSelection{})
+ caddy.RegisterModule(RoundRobinSelection{})
+ caddy.RegisterModule(FirstSelection{})
+ caddy.RegisterModule(IPHashSelection{})
+ caddy.RegisterModule(URIHashSelection{})
+ caddy.RegisterModule(HeaderHashSelection{})
+
+ weakrand.Seed(time.Now().UTC().UnixNano())
+}
+
+// RandomSelection is a policy that selects
+// an available host at random.
+type RandomSelection struct{}
+
+// CaddyModule returns the Caddy module information.
+func (RandomSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.random",
+ New: func() caddy.Module { return new(RandomSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (r RandomSelection) Select(pool UpstreamPool, request *http.Request) *Upstream {
+ // use reservoir sampling because the number of available
+ // hosts isn't known: https://en.wikipedia.org/wiki/Reservoir_sampling
+ var randomHost *Upstream
+ var count int
+ for _, upstream := range pool {
+ if !upstream.Available() {
+ continue
+ }
+ // (n % 1 == 0) holds for all n, therefore a
+ // upstream will always be chosen if there is at
+ // least one available
+ count++
+ if (weakrand.Int() % count) == 0 {
+ randomHost = upstream
+ }
+ }
+ return randomHost
+}
+
+// RandomChoiceSelection is a policy that selects
+// two or more available hosts at random, then
+// chooses the one with the least load.
+type RandomChoiceSelection struct {
+ Choose int `json:"choose,omitempty"`
+}
+
+// CaddyModule returns the Caddy module information.
+func (RandomChoiceSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.random_choose",
+ New: func() caddy.Module { return new(RandomChoiceSelection) },
+ }
+}
+
+// Provision sets up r.
+func (r *RandomChoiceSelection) Provision(ctx caddy.Context) error {
+ if r.Choose == 0 {
+ r.Choose = 2
+ }
+ return nil
+}
+
+// Validate ensures that r's configuration is valid.
+func (r RandomChoiceSelection) Validate() error {
+ if r.Choose < 2 {
+ return fmt.Errorf("choose must be at least 2")
+ }
+ return nil
+}
+
+// Select returns an available host, if any.
+func (r RandomChoiceSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream {
+ k := r.Choose
+ if k > len(pool) {
+ k = len(pool)
+ }
+ choices := make([]*Upstream, k)
+ for i, upstream := range pool {
+ if !upstream.Available() {
+ continue
+ }
+ j := weakrand.Intn(i)
+ if j < k {
+ choices[j] = upstream
+ }
+ }
+ return leastRequests(choices)
+}
+
+// LeastConnSelection is a policy that selects the
+// host with the least active requests. If multiple
+// hosts have the same fewest number, one is chosen
+// randomly. The term "conn" or "connection" is used
+// in this policy name due to its similar meaning in
+// other software, but our load balancer actually
+// counts active requests rather than connections,
+// since these days requests are multiplexed onto
+// shared connections.
+type LeastConnSelection struct{}
+
+// CaddyModule returns the Caddy module information.
+func (LeastConnSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.least_conn",
+ New: func() caddy.Module { return new(LeastConnSelection) },
+ }
+}
+
+// Select selects the up host with the least number of connections in the
+// pool. If more than one host has the same least number of connections,
+// one of the hosts is chosen at random.
+func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream {
+ var bestHost *Upstream
+ var count int
+ leastReqs := -1
+
+ for _, host := range pool {
+ if !host.Available() {
+ continue
+ }
+ numReqs := host.NumRequests()
+ if leastReqs == -1 || numReqs < leastReqs {
+ leastReqs = numReqs
+ count = 0
+ }
+
+ // among hosts with same least connections, perform a reservoir
+ // sample: https://en.wikipedia.org/wiki/Reservoir_sampling
+ if numReqs == leastReqs {
+ count++
+ if (weakrand.Int() % count) == 0 {
+ bestHost = host
+ }
+ }
+ }
+
+ return bestHost
+}
+
+// RoundRobinSelection is a policy that selects
+// a host based on round-robin ordering.
+type RoundRobinSelection struct {
+ robin uint32
+}
+
+// CaddyModule returns the Caddy module information.
+func (RoundRobinSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.round_robin",
+ New: func() caddy.Module { return new(RoundRobinSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (r *RoundRobinSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream {
+ n := uint32(len(pool))
+ if n == 0 {
+ return nil
+ }
+ for i := uint32(0); i < n; i++ {
+ atomic.AddUint32(&r.robin, 1)
+ host := pool[r.robin%n]
+ if host.Available() {
+ return host
+ }
+ }
+ return nil
+}
+
+// FirstSelection is a policy that selects
+// the first available host.
+type FirstSelection struct{}
+
+// CaddyModule returns the Caddy module information.
+func (FirstSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.first",
+ New: func() caddy.Module { return new(FirstSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (FirstSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream {
+ for _, host := range pool {
+ if host.Available() {
+ return host
+ }
+ }
+ return nil
+}
+
+// IPHashSelection is a policy that selects a host
+// based on hashing the remote IP of the request.
+type IPHashSelection struct{}
+
+// CaddyModule returns the Caddy module information.
+func (IPHashSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.ip_hash",
+ New: func() caddy.Module { return new(IPHashSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (IPHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream {
+ clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
+ if err != nil {
+ clientIP = req.RemoteAddr
+ }
+ return hostByHashing(pool, clientIP)
+}
+
+// URIHashSelection is a policy that selects a
+// host by hashing the request URI.
+type URIHashSelection struct{}
+
+// CaddyModule returns the Caddy module information.
+func (URIHashSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.uri_hash",
+ New: func() caddy.Module { return new(URIHashSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (URIHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream {
+ return hostByHashing(pool, req.RequestURI)
+}
+
+// HeaderHashSelection is a policy that selects
+// a host based on a given request header.
+type HeaderHashSelection struct {
+ Field string `json:"field,omitempty"`
+}
+
+// CaddyModule returns the Caddy module information.
+func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.header",
+ New: func() caddy.Module { return new(HeaderHashSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream {
+ if s.Field == "" {
+ return nil
+ }
+ val := req.Header.Get(s.Field)
+ if val == "" {
+ return RandomSelection{}.Select(pool, req)
+ }
+ return hostByHashing(pool, val)
+}
+
+// leastRequests returns the host with the
+// least number of active requests to it.
+// If more than one host has the same
+// least number of active requests, then
+// one of those is chosen at random.
+func leastRequests(upstreams []*Upstream) *Upstream {
+ if len(upstreams) == 0 {
+ return nil
+ }
+ var best []*Upstream
+ var bestReqs int
+ for _, upstream := range upstreams {
+ reqs := upstream.NumRequests()
+ if reqs == 0 {
+ return upstream
+ }
+ if reqs <= bestReqs {
+ bestReqs = reqs
+ best = append(best, upstream)
+ }
+ }
+ return best[weakrand.Intn(len(best))]
+}
+
+// hostByHashing returns an available host
+// from pool based on a hashable string s.
+func hostByHashing(pool []*Upstream, s string) *Upstream {
+ poolLen := uint32(len(pool))
+ if poolLen == 0 {
+ return nil
+ }
+ index := hash(s) % poolLen
+ for i := uint32(0); i < poolLen; i++ {
+ index += i
+ upstream := pool[index%poolLen]
+ if upstream.Available() {
+ return upstream
+ }
+ }
+ return nil
+}
+
+// hash calculates a fast hash based on s.
+func hash(s string) uint32 {
+ h := fnv.New32a()
+ h.Write([]byte(s))
+ return h.Sum32()
+}
+
+// Interface guards
+var (
+ _ Selector = (*RandomSelection)(nil)
+ _ Selector = (*RandomChoiceSelection)(nil)
+ _ Selector = (*LeastConnSelection)(nil)
+ _ Selector = (*RoundRobinSelection)(nil)
+ _ Selector = (*FirstSelection)(nil)
+ _ Selector = (*IPHashSelection)(nil)
+ _ Selector = (*URIHashSelection)(nil)
+ _ Selector = (*HeaderHashSelection)(nil)
+
+ _ caddy.Validator = (*RandomChoiceSelection)(nil)
+ _ caddy.Provisioner = (*RandomChoiceSelection)(nil)
+)
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
new file mode 100644
index 0000000..e9939d6
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
@@ -0,0 +1,273 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reverseproxy
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func testPool() UpstreamPool {
+ return UpstreamPool{
+ {Host: new(upstreamHost)},
+ {Host: new(upstreamHost)},
+ {Host: new(upstreamHost)},
+ }
+}
+
+func TestRoundRobinPolicy(t *testing.T) {
+ pool := testPool()
+ rrPolicy := new(RoundRobinSelection)
+ req, _ := http.NewRequest("GET", "/", nil)
+
+ h := rrPolicy.Select(pool, req)
+ // First selected host is 1, because counter starts at 0
+ // and increments before host is selected
+ if h != pool[1] {
+ t.Error("Expected first round robin host to be second host in the pool.")
+ }
+ h = rrPolicy.Select(pool, req)
+ if h != pool[2] {
+ t.Error("Expected second round robin host to be third host in the pool.")
+ }
+ h = rrPolicy.Select(pool, req)
+ if h != pool[0] {
+ t.Error("Expected third round robin host to be first host in the pool.")
+ }
+ // mark host as down
+ pool[1].SetHealthy(false)
+ h = rrPolicy.Select(pool, req)
+ if h != pool[2] {
+ t.Error("Expected to skip down host.")
+ }
+ // mark host as up
+ pool[1].SetHealthy(true)
+
+ h = rrPolicy.Select(pool, req)
+ if h == pool[2] {
+ t.Error("Expected to balance evenly among healthy hosts")
+ }
+ // mark host as full
+ pool[1].CountRequest(1)
+ pool[1].MaxRequests = 1
+ h = rrPolicy.Select(pool, req)
+ if h != pool[2] {
+ t.Error("Expected to skip full host.")
+ }
+}
+
+func TestLeastConnPolicy(t *testing.T) {
+ pool := testPool()
+ lcPolicy := new(LeastConnSelection)
+ req, _ := http.NewRequest("GET", "/", nil)
+
+ pool[0].CountRequest(10)
+ pool[1].CountRequest(10)
+ h := lcPolicy.Select(pool, req)
+ if h != pool[2] {
+ t.Error("Expected least connection host to be third host.")
+ }
+ pool[2].CountRequest(100)
+ h = lcPolicy.Select(pool, req)
+ if h != pool[0] && h != pool[1] {
+ t.Error("Expected least connection host to be first or second host.")
+ }
+}
+
+func TestIPHashPolicy(t *testing.T) {
+ pool := testPool()
+ ipHash := new(IPHashSelection)
+ req, _ := http.NewRequest("GET", "/", nil)
+
+ // We should be able to predict where every request is routed.
+ req.RemoteAddr = "172.0.0.1:80"
+ h := ipHash.Select(pool, req)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ req.RemoteAddr = "172.0.0.2:80"
+ h = ipHash.Select(pool, req)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ req.RemoteAddr = "172.0.0.3:80"
+ h = ipHash.Select(pool, req)
+ if h != pool[2] {
+ t.Error("Expected ip hash policy host to be the third host.")
+ }
+ req.RemoteAddr = "172.0.0.4:80"
+ h = ipHash.Select(pool, req)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+
+ // we should get the same results without a port
+ req.RemoteAddr = "172.0.0.1"
+ h = ipHash.Select(pool, req)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ req.RemoteAddr = "172.0.0.2"
+ h = ipHash.Select(pool, req)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ req.RemoteAddr = "172.0.0.3"
+ h = ipHash.Select(pool, req)
+ if h != pool[2] {
+ t.Error("Expected ip hash policy host to be the third host.")
+ }
+ req.RemoteAddr = "172.0.0.4"
+ h = ipHash.Select(pool, req)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+
+ // we should get a healthy host if the original host is unhealthy and a
+ // healthy host is available
+ req.RemoteAddr = "172.0.0.1"
+ pool[1].SetHealthy(false)
+ h = ipHash.Select(pool, req)
+ if h != pool[2] {
+ t.Error("Expected ip hash policy host to be the third host.")
+ }
+
+ req.RemoteAddr = "172.0.0.2"
+ h = ipHash.Select(pool, req)
+ if h != pool[2] {
+ t.Error("Expected ip hash policy host to be the third host.")
+ }
+ pool[1].SetHealthy(true)
+
+ req.RemoteAddr = "172.0.0.3"
+ pool[2].SetHealthy(false)
+ h = ipHash.Select(pool, req)
+ if h != pool[0] {
+ t.Error("Expected ip hash policy host to be the first host.")
+ }
+ req.RemoteAddr = "172.0.0.4"
+ h = ipHash.Select(pool, req)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+
+ // We should be able to resize the host pool and still be able to predict
+ // where a req will be routed with the same IP's used above
+ pool = UpstreamPool{
+ {Host: new(upstreamHost)},
+ {Host: new(upstreamHost)},
+ }
+ req.RemoteAddr = "172.0.0.1:80"
+ h = ipHash.Select(pool, req)
+ if h != pool[0] {
+ t.Error("Expected ip hash policy host to be the first host.")
+ }
+ req.RemoteAddr = "172.0.0.2:80"
+ h = ipHash.Select(pool, req)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+ req.RemoteAddr = "172.0.0.3:80"
+ h = ipHash.Select(pool, req)
+ if h != pool[0] {
+ t.Error("Expected ip hash policy host to be the first host.")
+ }
+ req.RemoteAddr = "172.0.0.4:80"
+ h = ipHash.Select(pool, req)
+ if h != pool[1] {
+ t.Error("Expected ip hash policy host to be the second host.")
+ }
+
+ // We should get nil when there are no healthy hosts
+ pool[0].SetHealthy(false)
+ pool[1].SetHealthy(false)
+ h = ipHash.Select(pool, req)
+ if h != nil {
+ t.Error("Expected ip hash policy host to be nil.")
+ }
+}
+
+func TestFirstPolicy(t *testing.T) {
+ pool := testPool()
+ firstPolicy := new(FirstSelection)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+
+ h := firstPolicy.Select(pool, req)
+ if h != pool[0] {
+ t.Error("Expected first policy host to be the first host.")
+ }
+
+ pool[0].SetHealthy(false)
+ h = firstPolicy.Select(pool, req)
+ if h != pool[1] {
+ t.Error("Expected first policy host to be the second host.")
+ }
+}
+
+func TestURIHashPolicy(t *testing.T) {
+ pool := testPool()
+ uriPolicy := new(URIHashSelection)
+
+ request := httptest.NewRequest(http.MethodGet, "/test", nil)
+ h := uriPolicy.Select(pool, request)
+ if h != pool[0] {
+ t.Error("Expected uri policy host to be the first host.")
+ }
+
+ pool[0].SetHealthy(false)
+ h = uriPolicy.Select(pool, request)
+ if h != pool[1] {
+ t.Error("Expected uri policy host to be the first host.")
+ }
+
+ request = httptest.NewRequest(http.MethodGet, "/test_2", nil)
+ h = uriPolicy.Select(pool, request)
+ if h != pool[1] {
+ t.Error("Expected uri policy host to be the second host.")
+ }
+
+ // We should be able to resize the host pool and still be able to predict
+ // where a request will be routed with the same URI's used above
+ pool = UpstreamPool{
+ {Host: new(upstreamHost)},
+ {Host: new(upstreamHost)},
+ }
+
+ request = httptest.NewRequest(http.MethodGet, "/test", nil)
+ h = uriPolicy.Select(pool, request)
+ if h != pool[0] {
+ t.Error("Expected uri policy host to be the first host.")
+ }
+
+ pool[0].SetHealthy(false)
+ h = uriPolicy.Select(pool, request)
+ if h != pool[1] {
+ t.Error("Expected uri policy host to be the first host.")
+ }
+
+ request = httptest.NewRequest(http.MethodGet, "/test_2", nil)
+ h = uriPolicy.Select(pool, request)
+ if h != pool[1] {
+ t.Error("Expected uri policy host to be the second host.")
+ }
+
+ pool[0].SetHealthy(false)
+ pool[1].SetHealthy(false)
+ h = uriPolicy.Select(pool, request)
+ if h != nil {
+ t.Error("Expected uri policy policy host to be nil.")
+ }
+}
diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go
new file mode 100644
index 0000000..a3b711b
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/streaming.go
@@ -0,0 +1,223 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Most of the code in this file was initially borrowed from the Go
+// standard library, which has this copyright notice:
+// Copyright 2011 The Go Authors.
+
+package reverseproxy
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "sync"
+ "time"
+)
+
+func (h Handler) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
+ reqUpType := upgradeType(req.Header)
+ resUpType := upgradeType(res.Header)
+ if reqUpType != resUpType {
+ // TODO: figure out our own error handling
+ // p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
+ return
+ }
+
+ copyHeader(res.Header, rw.Header())
+
+ hj, ok := rw.(http.Hijacker)
+ if !ok {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
+ return
+ }
+ backConn, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
+ return
+ }
+ defer backConn.Close()
+ conn, brw, err := hj.Hijack()
+ if err != nil {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
+ return
+ }
+ defer conn.Close()
+ res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
+ if err := res.Write(brw); err != nil {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
+ return
+ }
+ if err := brw.Flush(); err != nil {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
+ return
+ }
+ errc := make(chan error, 1)
+ spc := switchProtocolCopier{user: conn, backend: backConn}
+ go spc.copyToBackend(errc)
+ go spc.copyFromBackend(errc)
+ <-errc
+ return
+}
+
+// flushInterval returns the p.FlushInterval value, conditionally
+// overriding its value for a specific request/response.
+func (h Handler) flushInterval(req *http.Request, res *http.Response) time.Duration {
+ resCT := res.Header.Get("Content-Type")
+
+ // For Server-Sent Events responses, flush immediately.
+ // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
+ if resCT == "text/event-stream" {
+ return -1 // negative means immediately
+ }
+
+ // TODO: more specific cases? e.g. res.ContentLength == -1? (this TODO is from the std lib)
+ return time.Duration(h.FlushInterval)
+}
+
+func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
+ if flushInterval != 0 {
+ if wf, ok := dst.(writeFlusher); ok {
+ mlw := &maxLatencyWriter{
+ dst: wf,
+ latency: flushInterval,
+ }
+ defer mlw.stop()
+ dst = mlw
+ }
+ }
+
+ // TODO: Figure out how we want to do this... using custom buffer pool type seems unnecessary
+ // or maybe it is, depending on how we want to handle errors,
+ // see: https://github.com/golang/go/issues/21814
+ // buf := bufPool.Get().(*bytes.Buffer)
+ // buf.Reset()
+ // defer bufPool.Put(buf)
+ // _, err := io.CopyBuffer(dst, src, )
+ var buf []byte
+ // if h.BufferPool != nil {
+ // buf = h.BufferPool.Get()
+ // defer h.BufferPool.Put(buf)
+ // }
+ _, err := h.copyBuffer(dst, src, buf)
+ return err
+}
+
+// copyBuffer returns any write errors or non-EOF read errors, and the amount
+// of bytes written.
+func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
+ if len(buf) == 0 {
+ buf = make([]byte, 32*1024)
+ }
+ var written int64
+ for {
+ nr, rerr := src.Read(buf)
+ if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
+ // TODO: this could be useful to know (indeed, it revealed an error in our
+ // fastcgi PoC earlier; but it's this single error report here that necessitates
+ // a function separate from io.CopyBuffer, since io.CopyBuffer does not distinguish
+ // between read or write errors; in a reverse proxy situation, write errors are not
+ // something we need to report to the client, but read errors are a problem on our
+ // end for sure. so we need to decide what we want.)
+ // p.logf("copyBuffer: ReverseProxy read error during body copy: %v", rerr)
+ }
+ if nr > 0 {
+ nw, werr := dst.Write(buf[:nr])
+ if nw > 0 {
+ written += int64(nw)
+ }
+ if werr != nil {
+ return written, werr
+ }
+ if nr != nw {
+ return written, io.ErrShortWrite
+ }
+ }
+ if rerr != nil {
+ if rerr == io.EOF {
+ rerr = nil
+ }
+ return written, rerr
+ }
+ }
+}
+
+type writeFlusher interface {
+ io.Writer
+ http.Flusher
+}
+
+type maxLatencyWriter struct {
+ dst writeFlusher
+ latency time.Duration // non-zero; negative means to flush immediately
+
+ mu sync.Mutex // protects t, flushPending, and dst.Flush
+ t *time.Timer
+ flushPending bool
+}
+
+func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ n, err = m.dst.Write(p)
+ if m.latency < 0 {
+ m.dst.Flush()
+ return
+ }
+ if m.flushPending {
+ return
+ }
+ if m.t == nil {
+ m.t = time.AfterFunc(m.latency, m.delayedFlush)
+ } else {
+ m.t.Reset(m.latency)
+ }
+ m.flushPending = true
+ return
+}
+
+func (m *maxLatencyWriter) delayedFlush() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
+ return
+ }
+ m.dst.Flush()
+ m.flushPending = false
+}
+
+func (m *maxLatencyWriter) stop() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.flushPending = false
+ if m.t != nil {
+ m.t.Stop()
+ }
+}
+
+// switchProtocolCopier exists so goroutines proxying data back and
+// forth have nice names in stacks.
+type switchProtocolCopier struct {
+ user, backend io.ReadWriter
+}
+
+func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
+ _, err := io.Copy(c.user, c.backend)
+ errc <- err
+}
+
+func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
+ _, err := io.Copy(c.backend, c.user)
+ errc <- err
+}
diff --git a/modules/caddyhttp/reverseproxy/upstream.go b/modules/caddyhttp/reverseproxy/upstream.go
deleted file mode 100755
index 1f0693e..0000000
--- a/modules/caddyhttp/reverseproxy/upstream.go
+++ /dev/null
@@ -1,450 +0,0 @@
-// Copyright 2015 Matthew Holt and The Caddy Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package reverseproxy implements a load-balanced reverse proxy.
-package reverseproxy
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "math/rand"
- "net"
- "net/http"
- "net/url"
- "strings"
- "sync"
- "sync/atomic"
- "time"
-
- "github.com/caddyserver/caddy/v2"
- "github.com/caddyserver/caddy/v2/modules/caddyhttp"
-)
-
-// CircuitBreaker defines the functionality of a circuit breaker module.
-type CircuitBreaker interface {
- Ok() bool
- RecordMetric(statusCode int, latency time.Duration)
-}
-
-type noopCircuitBreaker struct{}
-
-func (ncb noopCircuitBreaker) RecordMetric(statusCode int, latency time.Duration) {}
-func (ncb noopCircuitBreaker) Ok() bool {
- return true
-}
-
-const (
- // TypeBalanceRoundRobin represents the value to use for configuring a load balanced reverse proxy to use round robin load balancing.
- TypeBalanceRoundRobin = iota
-
- // TypeBalanceRandom represents the value to use for configuring a load balanced reverse proxy to use random load balancing.
- TypeBalanceRandom
-
- // TODO: add random with two choices
-
- // msgNoHealthyUpstreams is returned if there are no upstreams that are healthy to proxy a request to
- msgNoHealthyUpstreams = "No healthy upstreams."
-
- // by default perform health checks every 30 seconds
- defaultHealthCheckDur = time.Second * 30
-
- // used when an upstream is unhealthy, health checks can be configured to perform at a faster rate
- defaultFastHealthCheckDur = time.Second * 1
-)
-
-var (
- // defaultTransport is the default transport to use for the reverse proxy.
- defaultTransport = &http.Transport{
- Dial: (&net.Dialer{
- Timeout: 5 * time.Second,
- }).Dial,
- TLSHandshakeTimeout: 5 * time.Second,
- }
-
- // defaultHTTPClient is the default http client to use for the healthchecker.
- defaultHTTPClient = &http.Client{
- Timeout: time.Second * 10,
- Transport: defaultTransport,
- }
-
- // typeMap maps caddy load balance configuration to the internal representation of the loadbalance algorithm type.
- typeMap = map[string]int{
- "round_robin": TypeBalanceRoundRobin,
- "random": TypeBalanceRandom,
- }
-)
-
-// NewLoadBalancedReverseProxy returns a collection of Upstreams that are to be loadbalanced.
-func NewLoadBalancedReverseProxy(lb *LoadBalanced, ctx caddy.Context) error {
- // set defaults
- if lb.NoHealthyUpstreamsMessage == "" {
- lb.NoHealthyUpstreamsMessage = msgNoHealthyUpstreams
- }
-
- if lb.TryInterval == "" {
- lb.TryInterval = "20s"
- }
-
- // set request retry interval
- ti, err := time.ParseDuration(lb.TryInterval)
- if err != nil {
- return fmt.Errorf("NewLoadBalancedReverseProxy: %v", err.Error())
- }
- lb.tryInterval = ti
-
- // set load balance algorithm
- t, ok := typeMap[lb.LoadBalanceType]
- if !ok {
- t = TypeBalanceRandom
- }
- lb.loadBalanceType = t
-
- // setup each upstream
- var us []*upstream
- for _, uc := range lb.Upstreams {
- // pass the upstream decr and incr methods to keep track of unhealthy nodes
- nu, err := newUpstream(uc, lb.decrUnhealthy, lb.incrUnhealthy)
- if err != nil {
- return err
- }
-
- // setup any configured circuit breakers
- var cbModule = "http.handlers.reverse_proxy.circuit_breaker"
- var cb CircuitBreaker
-
- if uc.CircuitBreaker != nil {
- if _, err := caddy.GetModule(cbModule); err == nil {
- val, err := ctx.LoadModule(cbModule, uc.CircuitBreaker)
- if err == nil {
- cbv, ok := val.(CircuitBreaker)
- if ok {
- cb = cbv
- } else {
- fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error())
- cb = noopCircuitBreaker{}
- }
- } else {
- fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error())
- cb = noopCircuitBreaker{}
- }
- } else {
- fmt.Println("circuit_breaker module not loaded, using noop")
- cb = noopCircuitBreaker{}
- }
- } else {
- cb = noopCircuitBreaker{}
- }
- nu.CB = cb
-
- // start a healthcheck worker which will periodically check to see if an upstream is healthy
- // to proxy requests to.
- nu.healthChecker = NewHealthCheckWorker(nu, defaultHealthCheckDur, defaultHTTPClient)
-
- // TODO :- if path is empty why does this empty the entire Target?
- // nu.Target.Path = uc.HealthCheckPath
-
- nu.healthChecker.ScheduleChecks(nu.Target.String())
- lb.HealthCheckers = append(lb.HealthCheckers, nu.healthChecker)
-
- us = append(us, nu)
- }
-
- lb.upstreams = us
-
- return nil
-}
-
-// LoadBalanced represents a collection of upstream hosts that are loadbalanced. It
-// contains multiple features like health checking and circuit breaking functionality
-// for upstreams.
-type LoadBalanced struct {
- mu sync.Mutex
- numUnhealthy int32
- selectedServer int // used during round robin load balancing
- loadBalanceType int
- tryInterval time.Duration
- upstreams []*upstream
-
- // The following struct fields are set by caddy configuration.
- // TryInterval is the max duration for which request retrys will be performed for a request.
- TryInterval string `json:"try_interval,omitempty"`
-
- // Upstreams are the configs for upstream hosts
- Upstreams []*UpstreamConfig `json:"upstreams,omitempty"`
-
- // LoadBalanceType is the string representation of what loadbalancing algorithm to use. i.e. "random" or "round_robin".
- LoadBalanceType string `json:"load_balance_type,omitempty"`
-
- // NoHealthyUpstreamsMessage is returned as a response when there are no healthy upstreams to loadbalance to.
- NoHealthyUpstreamsMessage string `json:"no_healthy_upstreams_message,omitempty"`
-
- // TODO :- store healthcheckers as package level state where each upstream gets a single healthchecker
- // currently a healthchecker is created for each upstream defined, even if a healthchecker was previously created
- // for that upstream
- HealthCheckers []*HealthChecker `json:"health_checkers,omitempty"`
-}
-
-// Cleanup stops all health checkers on a loadbalanced reverse proxy.
-func (lb *LoadBalanced) Cleanup() error {
- for _, hc := range lb.HealthCheckers {
- hc.Stop()
- }
-
- return nil
-}
-
-// Provision sets up a new loadbalanced reverse proxy.
-func (lb *LoadBalanced) Provision(ctx caddy.Context) error {
- return NewLoadBalancedReverseProxy(lb, ctx)
-}
-
-// ServeHTTP implements the caddyhttp.MiddlewareHandler interface to
-// dispatch an HTTP request to the proper server.
-func (lb *LoadBalanced) ServeHTTP(w http.ResponseWriter, r *http.Request, _ caddyhttp.Handler) error {
- // ensure requests don't hang if an upstream does not respond or is not eventually healthy
- var u *upstream
- var done bool
-
- retryTimer := time.NewTicker(lb.tryInterval)
- defer retryTimer.Stop()
-
- go func() {
- select {
- case <-retryTimer.C:
- done = true
- }
- }()
-
- // keep trying to get an available upstream to process the request
- for {
- switch lb.loadBalanceType {
- case TypeBalanceRandom:
- u = lb.random()
- case TypeBalanceRoundRobin:
- u = lb.roundRobin()
- }
-
- // if we can't get an upstream and our retry interval has ended return an error response
- if u == nil && done {
- w.WriteHeader(http.StatusBadGateway)
- fmt.Fprint(w, lb.NoHealthyUpstreamsMessage)
-
- return fmt.Errorf(msgNoHealthyUpstreams)
- }
-
- // attempt to get an available upstream
- if u == nil {
- continue
- }
-
- start := time.Now()
-
- // if we get an error retry until we get a healthy upstream
- res, err := u.ReverseProxy.ServeHTTP(w, r)
- if err != nil {
- if err == context.Canceled {
- return nil
- }
-
- continue
- }
-
- // record circuit breaker metrics
- go u.CB.RecordMetric(res.StatusCode, time.Now().Sub(start))
-
- return nil
- }
-}
-
-// incrUnhealthy increments the amount of unhealthy nodes in a loadbalancer.
-func (lb *LoadBalanced) incrUnhealthy() {
- atomic.AddInt32(&lb.numUnhealthy, 1)
-}
-
-// decrUnhealthy decrements the amount of unhealthy nodes in a loadbalancer.
-func (lb *LoadBalanced) decrUnhealthy() {
- atomic.AddInt32(&lb.numUnhealthy, -1)
-}
-
-// roundRobin implements a round robin load balancing algorithm to select
-// which server to forward requests to.
-func (lb *LoadBalanced) roundRobin() *upstream {
- if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) {
- return nil
- }
-
- selected := lb.upstreams[lb.selectedServer]
-
- lb.mu.Lock()
- lb.selectedServer++
- if lb.selectedServer >= len(lb.upstreams) {
- lb.selectedServer = 0
- }
- lb.mu.Unlock()
-
- if selected.IsHealthy() && selected.CB.Ok() {
- return selected
- }
-
- return nil
-}
-
-// random implements a random server selector for load balancing.
-func (lb *LoadBalanced) random() *upstream {
- if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) {
- return nil
- }
-
- n := rand.Int() % len(lb.upstreams)
- selected := lb.upstreams[n]
-
- if selected.IsHealthy() && selected.CB.Ok() {
- return selected
- }
-
- return nil
-}
-
-// UpstreamConfig represents the config of an upstream.
-type UpstreamConfig struct {
- // Host is the host name of the upstream server.
- Host string `json:"host,omitempty"`
-
- // FastHealthCheckDuration is the duration for which a health check is performed when a node is considered unhealthy.
- FastHealthCheckDuration string `json:"fast_health_check_duration,omitempty"`
-
- CircuitBreaker json.RawMessage `json:"circuit_breaker,omitempty"`
-
- // // CircuitBreakerConfig is the config passed to setup a circuit breaker.
- // CircuitBreakerConfig *circuitbreaker.Config `json:"circuit_breaker,omitempty"`
- circuitbreaker CircuitBreaker
-
- // HealthCheckDuration is the default duration for which a health check is performed.
- HealthCheckDuration string `json:"health_check_duration,omitempty"`
-
- // HealthCheckPath is the path at the upstream host to use for healthchecks.
- HealthCheckPath string `json:"health_check_path,omitempty"`
-}
-
-// upstream represents an upstream host.
-type upstream struct {
- Healthy int32 // 0 = false, 1 = true
- Target *url.URL
- ReverseProxy *ReverseProxy
- Incr func()
- Decr func()
- CB CircuitBreaker
- healthChecker *HealthChecker
- healthCheckDur time.Duration
- fastHealthCheckDur time.Duration
-}
-
-// newUpstream returns a new upstream.
-func newUpstream(uc *UpstreamConfig, d func(), i func()) (*upstream, error) {
- host := strings.TrimSpace(uc.Host)
- protoIdx := strings.Index(host, "://")
- if protoIdx == -1 || len(host[:protoIdx]) == 0 {
- return nil, fmt.Errorf("protocol is required for host")
- }
-
- hostURL, err := url.Parse(host)
- if err != nil {
- return nil, err
- }
-
- // parse healthcheck durations
- hcd, err := time.ParseDuration(uc.HealthCheckDuration)
- if err != nil {
- hcd = defaultHealthCheckDur
- }
-
- fhcd, err := time.ParseDuration(uc.FastHealthCheckDuration)
- if err != nil {
- fhcd = defaultFastHealthCheckDur
- }
-
- u := upstream{
- healthCheckDur: hcd,
- fastHealthCheckDur: fhcd,
- Target: hostURL,
- Decr: d,
- Incr: i,
- Healthy: int32(0), // assume is unhealthy on start
- }
-
- u.ReverseProxy = newReverseProxy(hostURL, u.SetHealthiness)
- return &u, nil
-}
-
-// SetHealthiness sets whether an upstream is healthy or not. The health check worker is updated to
-// perform checks faster if a node is unhealthy.
-func (u *upstream) SetHealthiness(ok bool) {
- h := atomic.LoadInt32(&u.Healthy)
- var wasHealthy bool
- if h == 1 {
- wasHealthy = true
- } else {
- wasHealthy = false
- }
-
- if ok {
- u.healthChecker.Ticker = time.NewTicker(u.healthCheckDur)
-
- if !wasHealthy {
- atomic.AddInt32(&u.Healthy, 1)
- u.Decr()
- }
- } else {
- u.healthChecker.Ticker = time.NewTicker(u.fastHealthCheckDur)
-
- if wasHealthy {
- atomic.AddInt32(&u.Healthy, -1)
- u.Incr()
- }
- }
-}
-
-// IsHealthy returns whether an Upstream is healthy or not.
-func (u *upstream) IsHealthy() bool {
- i := atomic.LoadInt32(&u.Healthy)
- if i == 1 {
- return true
- }
-
- return false
-}
-
-// newReverseProxy returns a new reverse proxy handler.
-func newReverseProxy(target *url.URL, setHealthiness func(bool)) *ReverseProxy {
- errorHandler := func(w http.ResponseWriter, r *http.Request, err error) {
- // we don't need to worry about cancelled contexts since this doesn't necessarilly mean that
- // the upstream is unhealthy.
- if err != context.Canceled {
- setHealthiness(false)
- }
- }
-
- rp := NewSingleHostReverseProxy(target)
- rp.ErrorHandler = errorHandler
- rp.Transport = defaultTransport // use default transport that times out in 5 seconds
- return rp
-}
-
-// Interface guards
-var (
- _ caddyhttp.MiddlewareHandler = (*LoadBalanced)(nil)
- _ caddy.Provisioner = (*LoadBalanced)(nil)
- _ caddy.CleanerUpper = (*LoadBalanced)(nil)
-)