diff options
Diffstat (limited to 'modules/caddyhttp/reverseproxy')
-rw-r--r-- | modules/caddyhttp/reverseproxy/caddyfile.go | 486 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go | 54 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/fastcgi/client.go | 578 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/fastcgi/client_test.go | 301 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go | 301 | ||||
-rwxr-xr-x | modules/caddyhttp/reverseproxy/healthchecker.go | 86 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/healthchecks.go | 270 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/hosts.go | 193 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/httptransport.go | 208 | ||||
-rwxr-xr-x | modules/caddyhttp/reverseproxy/module.go | 53 | ||||
-rw-r--r--[-rwxr-xr-x] | modules/caddyhttp/reverseproxy/reverseproxy.go | 784 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/selectionpolicies.go | 353 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/selectionpolicies_test.go | 273 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/streaming.go | 223 | ||||
-rwxr-xr-x | modules/caddyhttp/reverseproxy/upstream.go | 450 |
15 files changed, 3641 insertions, 972 deletions
diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go new file mode 100644 index 0000000..ffa3ca0 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -0,0 +1,486 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig" + "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" + "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" + "github.com/dustin/go-humanize" +) + +func init() { + httpcaddyfile.RegisterHandlerDirective("reverse_proxy", parseCaddyfile) +} + +func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { + rp := new(Handler) + err := rp.UnmarshalCaddyfile(h.Dispenser) + if err != nil { + return nil, err + } + return rp, nil +} + +// UnmarshalCaddyfile sets up the handler from Caddyfile tokens. Syntax: +// +// reverse_proxy [<matcher>] [<upstreams...>] { +// # upstreams +// to <upstreams...> +// +// # load balancing +// lb_policy <name> [<options...>] +// lb_try_duration <duration> +// lb_try_interval <interval> +// +// # active health checking +// health_path <path> +// health_port <port> +// health_interval <interval> +// health_timeout <duration> +// health_status <status> +// health_body <regexp> +// +// # passive health checking +// max_fails <num> +// fail_duration <duration> +// max_conns <num> +// unhealthy_status <status> +// unhealthy_latency <duration> +// +// # round trip +// transport <name> { +// ... +// } +// } +// +func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + for d.Next() { + for _, up := range d.RemainingArgs() { + h.Upstreams = append(h.Upstreams, &Upstream{ + Dial: up, + }) + } + + for d.NextBlock() { + switch d.Val() { + case "to": + args := d.RemainingArgs() + if len(args) == 0 { + return d.ArgErr() + } + for _, up := range args { + h.Upstreams = append(h.Upstreams, &Upstream{ + Dial: up, + }) + } + + case "lb_policy": + if !d.NextArg() { + return d.ArgErr() + } + if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil { + return d.Err("load balancing selection policy already specified") + } + name := d.Val() + mod, err := caddy.GetModule("http.handlers.reverse_proxy.selection_policies." + name) + if err != nil { + return d.Errf("getting load balancing policy module '%s': %v", mod.Name, err) + } + unm, ok := mod.New().(caddyfile.Unmarshaler) + if !ok { + return d.Errf("load balancing policy module '%s' is not a Caddyfile unmarshaler", mod.Name) + } + err = unm.UnmarshalCaddyfile(d.NewFromNextTokens()) + if err != nil { + return err + } + sel, ok := unm.(Selector) + if !ok { + return d.Errf("module %s is not a Selector", mod.Name) + } + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + h.LoadBalancing.SelectionPolicyRaw = caddyconfig.JSONModuleObject(sel, "policy", name, nil) + + case "lb_try_duration": + if !d.NextArg() { + return d.ArgErr() + } + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value %s: %v", d.Val(), err) + } + h.LoadBalancing.TryDuration = caddy.Duration(dur) + + case "lb_try_interval": + if !d.NextArg() { + return d.ArgErr() + } + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad interval value '%s': %v", d.Val(), err) + } + h.LoadBalancing.TryInterval = caddy.Duration(dur) + + case "health_path": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + h.HealthChecks.Active.Path = d.Val() + + case "health_port": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + portNum, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("bad port number '%s': %v", d.Val(), err) + } + h.HealthChecks.Active.Port = portNum + + case "health_interval": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad interval value %s: %v", d.Val(), err) + } + h.HealthChecks.Active.Interval = caddy.Duration(dur) + + case "health_timeout": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad timeout value %s: %v", d.Val(), err) + } + h.HealthChecks.Active.Timeout = caddy.Duration(dur) + + case "health_status": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + val := d.Val() + if len(val) == 3 && strings.HasSuffix(val, "xx") { + val = val[:1] + } + statusNum, err := strconv.Atoi(val[:1]) + if err != nil { + return d.Errf("bad status value '%s': %v", d.Val(), err) + } + h.HealthChecks.Active.ExpectStatus = statusNum + + case "health_body": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + h.HealthChecks.Active.ExpectBody = d.Val() + + case "max_fails": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + maxFails, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("invalid maximum fail count '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.MaxFails = maxFails + + case "fail_duration": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.FailDuration = caddy.Duration(dur) + + case "unhealthy_request_count": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + maxConns, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("invalid maximum connection count '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.UnhealthyRequestCount = maxConns + + case "unhealthy_status": + args := d.RemainingArgs() + if len(args) == 0 { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + for _, arg := range args { + if len(arg) == 3 && strings.HasSuffix(arg, "xx") { + arg = arg[:1] + } + statusNum, err := strconv.Atoi(arg[:1]) + if err != nil { + return d.Errf("bad status value '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.UnhealthyStatus = append(h.HealthChecks.Passive.UnhealthyStatus, statusNum) + } + + case "unhealthy_latency": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.UnhealthyLatency = caddy.Duration(dur) + + case "transport": + if !d.NextArg() { + return d.ArgErr() + } + if h.TransportRaw != nil { + return d.Err("transport already specified") + } + name := d.Val() + mod, err := caddy.GetModule("http.handlers.reverse_proxy.transport." + name) + if err != nil { + return d.Errf("getting transport module '%s': %v", mod.Name, err) + } + unm, ok := mod.New().(caddyfile.Unmarshaler) + if !ok { + return d.Errf("transport module '%s' is not a Caddyfile unmarshaler", mod.Name) + } + d.Next() // consume the module name token + err = unm.UnmarshalCaddyfile(d.NewFromNextTokens()) + if err != nil { + return err + } + rt, ok := unm.(http.RoundTripper) + if !ok { + return d.Errf("module %s is not a RoundTripper", mod.Name) + } + h.TransportRaw = caddyconfig.JSONModuleObject(rt, "protocol", name, nil) + + default: + return d.Errf("unrecognized subdirective %s", d.Val()) + } + } + } + + return nil +} + +// UnmarshalCaddyfile deserializes Caddyfile tokens into h. +// +// transport http { +// read_buffer <size> +// write_buffer <size> +// dial_timeout <duration> +// tls_client_auth <cert_file> <key_file> +// tls_insecure_skip_verify +// tls_timeout <duration> +// keepalive [off|<duration>] +// keepalive_idle_conns <max_count> +// } +// +func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + for d.NextBlock() { + switch d.Val() { + case "read_buffer": + if !d.NextArg() { + return d.ArgErr() + } + size, err := humanize.ParseBytes(d.Val()) + if err != nil { + return d.Errf("invalid read buffer size '%s': %v", d.Val(), err) + } + h.ReadBufferSize = int(size) + + case "write_buffer": + if !d.NextArg() { + return d.ArgErr() + } + size, err := humanize.ParseBytes(d.Val()) + if err != nil { + return d.Errf("invalid write buffer size '%s': %v", d.Val(), err) + } + h.WriteBufferSize = int(size) + + case "dial_timeout": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad timeout value '%s': %v", d.Val(), err) + } + h.DialTimeout = caddy.Duration(dur) + + case "tls_client_auth": + args := d.RemainingArgs() + if len(args) != 2 { + return d.ArgErr() + } + if h.TLS == nil { + h.TLS = new(TLSConfig) + } + h.TLS.ClientCertificateFile = args[0] + h.TLS.ClientCertificateKeyFile = args[1] + + case "tls_insecure_skip_verify": + if d.NextArg() { + return d.ArgErr() + } + if h.TLS == nil { + h.TLS = new(TLSConfig) + } + h.TLS.InsecureSkipVerify = true + + case "tls_timeout": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad timeout value '%s': %v", d.Val(), err) + } + if h.TLS == nil { + h.TLS = new(TLSConfig) + } + h.TLS.HandshakeTimeout = caddy.Duration(dur) + + case "keepalive": + if !d.NextArg() { + return d.ArgErr() + } + if h.KeepAlive == nil { + h.KeepAlive = new(KeepAlive) + } + if d.Val() == "off" { + var disable bool + h.KeepAlive.Enabled = &disable + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value '%s': %v", d.Val(), err) + } + h.KeepAlive.IdleConnTimeout = caddy.Duration(dur) + + case "keepalive_idle_conns": + if !d.NextArg() { + return d.ArgErr() + } + num, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("bad integer value '%s': %v", d.Val(), err) + } + if h.KeepAlive == nil { + h.KeepAlive = new(KeepAlive) + } + h.KeepAlive.MaxIdleConns = num + + default: + return d.Errf("unrecognized subdirective %s", d.Val()) + } + } + return nil +} + +// Interface guards +var ( + _ caddyfile.Unmarshaler = (*Handler)(nil) + _ caddyfile.Unmarshaler = (*HTTPTransport)(nil) +) diff --git a/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go b/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go new file mode 100644 index 0000000..c8b9f63 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go @@ -0,0 +1,54 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fastcgi + +import "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" + +// UnmarshalCaddyfile deserializes Caddyfile tokens into h. +// +// transport fastcgi { +// root <path> +// split <at> +// env <key> <value> +// } +// +func (t *Transport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + for d.NextBlock() { + switch d.Val() { + case "root": + if !d.NextArg() { + return d.ArgErr() + } + t.Root = d.Val() + + case "split": + if !d.NextArg() { + return d.ArgErr() + } + t.SplitPath = d.Val() + + case "env": + args := d.RemainingArgs() + if len(args) != 2 { + return d.ArgErr() + } + t.EnvVars = append(t.EnvVars, [2]string{args[0], args[1]}) + + default: + return d.Errf("unrecognized subdirective %s", d.Val()) + } + } + return nil +} diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client.go b/modules/caddyhttp/reverseproxy/fastcgi/client.go new file mode 100644 index 0000000..ae0de00 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/fastcgi/client.go @@ -0,0 +1,578 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Forked Jan. 2015 from http://bitbucket.org/PinIdea/fcgi_client +// (which is forked from https://code.google.com/p/go-fastcgi-client/). +// This fork contains several fixes and improvements by Matt Holt and +// other contributors to the Caddy project. + +// Copyright 2012 Junqing Tan <ivan@mysqlab.net> and The Go Authors +// Use of this source code is governed by a BSD-style +// Part of source code is from Go fcgi package + +package fastcgi + +import ( + "bufio" + "bytes" + "context" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "mime/multipart" + "net" + "net/http" + "net/http/httputil" + "net/textproto" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" +) + +// FCGIListenSockFileno describes listen socket file number. +const FCGIListenSockFileno uint8 = 0 + +// FCGIHeaderLen describes header length. +const FCGIHeaderLen uint8 = 8 + +// Version1 describes the version. +const Version1 uint8 = 1 + +// FCGINullRequestID describes the null request ID. +const FCGINullRequestID uint8 = 0 + +// FCGIKeepConn describes keep connection mode. +const FCGIKeepConn uint8 = 1 + +const ( + // BeginRequest is the begin request flag. + BeginRequest uint8 = iota + 1 + // AbortRequest is the abort request flag. + AbortRequest + // EndRequest is the end request flag. + EndRequest + // Params is the parameters flag. + Params + // Stdin is the standard input flag. + Stdin + // Stdout is the standard output flag. + Stdout + // Stderr is the standard error flag. + Stderr + // Data is the data flag. + Data + // GetValues is the get values flag. + GetValues + // GetValuesResult is the get values result flag. + GetValuesResult + // UnknownType is the unknown type flag. + UnknownType + // MaxType is the maximum type flag. + MaxType = UnknownType +) + +const ( + // Responder is the responder flag. + Responder uint8 = iota + 1 + // Authorizer is the authorizer flag. + Authorizer + // Filter is the filter flag. + Filter +) + +const ( + // RequestComplete is the completed request flag. + RequestComplete uint8 = iota + // CantMultiplexConns is the multiplexed connections flag. + CantMultiplexConns + // Overloaded is the overloaded flag. + Overloaded + // UnknownRole is the unknown role flag. + UnknownRole +) + +const ( + // MaxConns is the maximum connections flag. + MaxConns string = "MAX_CONNS" + // MaxRequests is the maximum requests flag. + MaxRequests string = "MAX_REQS" + // MultiplexConns is the multiplex connections flag. + MultiplexConns string = "MPXS_CONNS" +) + +const ( + maxWrite = 65500 // 65530 may work, but for compatibility + maxPad = 255 +) + +type header struct { + Version uint8 + Type uint8 + ID uint16 + ContentLength uint16 + PaddingLength uint8 + Reserved uint8 +} + +// for padding so we don't have to allocate all the time +// not synchronized because we don't care what the contents are +var pad [maxPad]byte + +func (h *header) init(recType uint8, reqID uint16, contentLength int) { + h.Version = 1 + h.Type = recType + h.ID = reqID + h.ContentLength = uint16(contentLength) + h.PaddingLength = uint8(-contentLength & 7) +} + +type record struct { + h header + rbuf []byte +} + +func (rec *record) read(r io.Reader) (buf []byte, err error) { + if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil { + return + } + if rec.h.Version != 1 { + err = errors.New("fcgi: invalid header version") + return + } + if rec.h.Type == EndRequest { + err = io.EOF + return + } + n := int(rec.h.ContentLength) + int(rec.h.PaddingLength) + if len(rec.rbuf) < n { + rec.rbuf = make([]byte, n) + } + if _, err = io.ReadFull(r, rec.rbuf[:n]); err != nil { + return + } + buf = rec.rbuf[:int(rec.h.ContentLength)] + + return +} + +// FCGIClient implements a FastCGI client, which is a standard for +// interfacing external applications with Web servers. +type FCGIClient struct { + mutex sync.Mutex + rwc io.ReadWriteCloser + h header + buf bytes.Buffer + stderr bytes.Buffer + keepAlive bool + reqID uint16 +} + +// DialWithDialerContext connects to the fcgi responder at the specified network address, using custom net.Dialer +// and a context. +// See func net.Dial for a description of the network and address parameters. +func DialWithDialerContext(ctx context.Context, network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) { + var conn net.Conn + conn, err = dialer.DialContext(ctx, network, address) + if err != nil { + return + } + + fcgi = &FCGIClient{ + rwc: conn, + keepAlive: false, + reqID: 1, + } + + return +} + +// DialContext is like Dial but passes ctx to dialer.Dial. +func DialContext(ctx context.Context, network, address string) (fcgi *FCGIClient, err error) { + // TODO: why not set timeout here? + return DialWithDialerContext(ctx, network, address, net.Dialer{}) +} + +// Dial connects to the fcgi responder at the specified network address, using default net.Dialer. +// See func net.Dial for a description of the network and address parameters. +func Dial(network, address string) (fcgi *FCGIClient, err error) { + return DialContext(context.Background(), network, address) +} + +// Close closes fcgi connection +func (c *FCGIClient) Close() { + c.rwc.Close() +} + +func (c *FCGIClient) writeRecord(recType uint8, content []byte) (err error) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.buf.Reset() + c.h.init(recType, c.reqID, len(content)) + if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { + return err + } + if _, err := c.buf.Write(content); err != nil { + return err + } + if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { + return err + } + _, err = c.rwc.Write(c.buf.Bytes()) + return err +} + +func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error { + b := [8]byte{byte(role >> 8), byte(role), flags} + return c.writeRecord(BeginRequest, b[:]) +} + +func (c *FCGIClient) writeEndRequest(appStatus int, protocolStatus uint8) error { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b, uint32(appStatus)) + b[4] = protocolStatus + return c.writeRecord(EndRequest, b) +} + +func (c *FCGIClient) writePairs(recType uint8, pairs map[string]string) error { + w := newWriter(c, recType) + b := make([]byte, 8) + nn := 0 + for k, v := range pairs { + m := 8 + len(k) + len(v) + if m > maxWrite { + // param data size exceed 65535 bytes" + vl := maxWrite - 8 - len(k) + v = v[:vl] + } + n := encodeSize(b, uint32(len(k))) + n += encodeSize(b[n:], uint32(len(v))) + m = n + len(k) + len(v) + if (nn + m) > maxWrite { + w.Flush() + nn = 0 + } + nn += m + if _, err := w.Write(b[:n]); err != nil { + return err + } + if _, err := w.WriteString(k); err != nil { + return err + } + if _, err := w.WriteString(v); err != nil { + return err + } + } + w.Close() + return nil +} + +func encodeSize(b []byte, size uint32) int { + if size > 127 { + size |= 1 << 31 + binary.BigEndian.PutUint32(b, size) + return 4 + } + b[0] = byte(size) + return 1 +} + +// bufWriter encapsulates bufio.Writer but also closes the underlying stream when +// Closed. +type bufWriter struct { + closer io.Closer + *bufio.Writer +} + +func (w *bufWriter) Close() error { + if err := w.Writer.Flush(); err != nil { + w.closer.Close() + return err + } + return w.closer.Close() +} + +func newWriter(c *FCGIClient, recType uint8) *bufWriter { + s := &streamWriter{c: c, recType: recType} + w := bufio.NewWriterSize(s, maxWrite) + return &bufWriter{s, w} +} + +// streamWriter abstracts out the separation of a stream into discrete records. +// It only writes maxWrite bytes at a time. +type streamWriter struct { + c *FCGIClient + recType uint8 +} + +func (w *streamWriter) Write(p []byte) (int, error) { + nn := 0 + for len(p) > 0 { + n := len(p) + if n > maxWrite { + n = maxWrite + } + if err := w.c.writeRecord(w.recType, p[:n]); err != nil { + return nn, err + } + nn += n + p = p[n:] + } + return nn, nil +} + +func (w *streamWriter) Close() error { + // send empty record to close the stream + return w.c.writeRecord(w.recType, nil) +} + +type streamReader struct { + c *FCGIClient + buf []byte +} + +func (w *streamReader) Read(p []byte) (n int, err error) { + + if len(p) > 0 { + if len(w.buf) == 0 { + + // filter outputs for error log + for { + rec := &record{} + var buf []byte + buf, err = rec.read(w.c.rwc) + if err != nil { + return + } + // standard error output + if rec.h.Type == Stderr { + w.c.stderr.Write(buf) + continue + } + w.buf = buf + break + } + } + + n = len(p) + if n > len(w.buf) { + n = len(w.buf) + } + copy(p, w.buf[:n]) + w.buf = w.buf[n:] + } + + return +} + +// Do made the request and returns a io.Reader that translates the data read +// from fcgi responder out of fcgi packet before returning it. +func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) { + err = c.writeBeginRequest(uint16(Responder), 0) + if err != nil { + return + } + + err = c.writePairs(Params, p) + if err != nil { + return + } + + body := newWriter(c, Stdin) + if req != nil { + _, _ = io.Copy(body, req) + } + body.Close() + + r = &streamReader{c: c} + return +} + +// clientCloser is a io.ReadCloser. It wraps a io.Reader with a Closer +// that closes FCGIClient connection. +type clientCloser struct { + *FCGIClient + io.Reader +} + +func (f clientCloser) Close() error { return f.rwc.Close() } + +// Request returns a HTTP Response with Header and Body +// from fcgi responder +func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Response, err error) { + r, err := c.Do(p, req) + if err != nil { + return + } + + rb := bufio.NewReader(r) + tp := textproto.NewReader(rb) + resp = new(http.Response) + + // Parse the response headers. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil && err != io.EOF { + return + } + resp.Header = http.Header(mimeHeader) + + if resp.Header.Get("Status") != "" { + statusParts := strings.SplitN(resp.Header.Get("Status"), " ", 2) + resp.StatusCode, err = strconv.Atoi(statusParts[0]) + if err != nil { + return + } + if len(statusParts) > 1 { + resp.Status = statusParts[1] + } + + } else { + resp.StatusCode = http.StatusOK + } + + // TODO: fixTransferEncoding ? + resp.TransferEncoding = resp.Header["Transfer-Encoding"] + resp.ContentLength, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + + if chunked(resp.TransferEncoding) { + resp.Body = clientCloser{c, httputil.NewChunkedReader(rb)} + } else { + resp.Body = clientCloser{c, ioutil.NopCloser(rb)} + } + return +} + +// Get issues a GET request to the fcgi responder. +func (c *FCGIClient) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) { + + p["REQUEST_METHOD"] = "GET" + p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10) + + return c.Request(p, body) +} + +// Head issues a HEAD request to the fcgi responder. +func (c *FCGIClient) Head(p map[string]string) (resp *http.Response, err error) { + + p["REQUEST_METHOD"] = "HEAD" + p["CONTENT_LENGTH"] = "0" + + return c.Request(p, nil) +} + +// Options issues an OPTIONS request to the fcgi responder. +func (c *FCGIClient) Options(p map[string]string) (resp *http.Response, err error) { + + p["REQUEST_METHOD"] = "OPTIONS" + p["CONTENT_LENGTH"] = "0" + + return c.Request(p, nil) +} + +// Post issues a POST request to the fcgi responder. with request body +// in the format that bodyType specified +func (c *FCGIClient) Post(p map[string]string, method string, bodyType string, body io.Reader, l int64) (resp *http.Response, err error) { + if p == nil { + p = make(map[string]string) + } + + p["REQUEST_METHOD"] = strings.ToUpper(method) + + if len(p["REQUEST_METHOD"]) == 0 || p["REQUEST_METHOD"] == "GET" { + p["REQUEST_METHOD"] = "POST" + } + + p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10) + if len(bodyType) > 0 { + p["CONTENT_TYPE"] = bodyType + } else { + p["CONTENT_TYPE"] = "application/x-www-form-urlencoded" + } + + return c.Request(p, body) +} + +// PostForm issues a POST to the fcgi responder, with form +// as a string key to a list values (url.Values) +func (c *FCGIClient) PostForm(p map[string]string, data url.Values) (resp *http.Response, err error) { + body := bytes.NewReader([]byte(data.Encode())) + return c.Post(p, "POST", "application/x-www-form-urlencoded", body, int64(body.Len())) +} + +// PostFile issues a POST to the fcgi responder in multipart(RFC 2046) standard, +// with form as a string key to a list values (url.Values), +// and/or with file as a string key to a list file path. +func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[string]string) (resp *http.Response, err error) { + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + bodyType := writer.FormDataContentType() + + for key, val := range data { + for _, v0 := range val { + err = writer.WriteField(key, v0) + if err != nil { + return + } + } + } + + for key, val := range file { + fd, e := os.Open(val) + if e != nil { + return nil, e + } + defer fd.Close() + + part, e := writer.CreateFormFile(key, filepath.Base(val)) + if e != nil { + return nil, e + } + _, err = io.Copy(part, fd) + if err != nil { + return + } + } + + err = writer.Close() + if err != nil { + return + } + + return c.Post(p, "POST", bodyType, buf, int64(buf.Len())) +} + +// SetReadTimeout sets the read timeout for future calls that read from the +// fcgi responder. A zero value for t means no timeout will be set. +func (c *FCGIClient) SetReadTimeout(t time.Duration) error { + if conn, ok := c.rwc.(net.Conn); ok && t != 0 { + return conn.SetReadDeadline(time.Now().Add(t)) + } + return nil +} + +// SetWriteTimeout sets the write timeout for future calls that send data to +// the fcgi responder. A zero value for t means no timeout will be set. +func (c *FCGIClient) SetWriteTimeout(t time.Duration) error { + if conn, ok := c.rwc.(net.Conn); ok && t != 0 { + return conn.SetWriteDeadline(time.Now().Add(t)) + } + return nil +} + +// Checks whether chunked is part of the encodings stack +func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client_test.go b/modules/caddyhttp/reverseproxy/fastcgi/client_test.go new file mode 100644 index 0000000..c090f3c --- /dev/null +++ b/modules/caddyhttp/reverseproxy/fastcgi/client_test.go @@ -0,0 +1,301 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// NOTE: These tests were adapted from the original +// repository from which this package was forked. +// The tests are slow (~10s) and in dire need of rewriting. +// As such, the tests have been disabled to speed up +// automated builds until they can be properly written. + +package fastcgi + +import ( + "bytes" + "crypto/md5" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "log" + "math/rand" + "net" + "net/http" + "net/http/fcgi" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" +) + +// test fcgi protocol includes: +// Get, Post, Post in multipart/form-data, and Post with files +// each key should be the md5 of the value or the file uploaded +// specify remote fcgi responder ip:port to test with php +// test failed if the remote fcgi(script) failed md5 verification +// and output "FAILED" in response +const ( + scriptFile = "/tank/www/fcgic_test.php" + //ipPort = "remote-php-serv:59000" + ipPort = "127.0.0.1:59000" +) + +var globalt *testing.T + +type FastCGIServer struct{} + +func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + + if err := req.ParseMultipartForm(100000000); err != nil { + log.Printf("[ERROR] failed to parse: %v", err) + } + + stat := "PASSED" + fmt.Fprintln(resp, "-") + fileNum := 0 + { + length := 0 + for k0, v0 := range req.Form { + h := md5.New() + _, _ = io.WriteString(h, v0[0]) + _md5 := fmt.Sprintf("%x", h.Sum(nil)) + + length += len(k0) + length += len(v0[0]) + + // echo error when key != _md5(val) + if _md5 != k0 { + fmt.Fprintln(resp, "server:err ", _md5, k0) + stat = "FAILED" + } + } + if req.MultipartForm != nil { + fileNum = len(req.MultipartForm.File) + for kn, fns := range req.MultipartForm.File { + //fmt.Fprintln(resp, "server:filekey ", kn ) + length += len(kn) + for _, f := range fns { + fd, err := f.Open() + if err != nil { + log.Println("server:", err) + return + } + h := md5.New() + l0, err := io.Copy(h, fd) + if err != nil { + log.Println(err) + return + } + length += int(l0) + defer fd.Close() + md5 := fmt.Sprintf("%x", h.Sum(nil)) + //fmt.Fprintln(resp, "server:filemd5 ", md5 ) + + if kn != md5 { + fmt.Fprintln(resp, "server:err ", md5, kn) + stat = "FAILED" + } + //fmt.Fprintln(resp, "server:filename ", f.Filename ) + } + } + } + + fmt.Fprintln(resp, "server:got data length", length) + } + fmt.Fprintln(resp, "-"+stat+"-POST(", len(req.Form), ")-FILE(", fileNum, ")--") +} + +func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) { + fcgi, err := Dial("tcp", ipPort) + if err != nil { + log.Println("err:", err) + return + } + + length := 0 + + var resp *http.Response + switch reqType { + case 0: + if len(data) > 0 { + length = len(data) + rd := bytes.NewReader(data) + resp, err = fcgi.Post(fcgiParams, "", "", rd, int64(rd.Len())) + } else if len(posts) > 0 { + values := url.Values{} + for k, v := range posts { + values.Set(k, v) + length += len(k) + 2 + len(v) + } + resp, err = fcgi.PostForm(fcgiParams, values) + } else { + rd := bytes.NewReader(data) + resp, err = fcgi.Get(fcgiParams, rd, int64(rd.Len())) + } + + default: + values := url.Values{} + for k, v := range posts { + values.Set(k, v) + length += len(k) + 2 + len(v) + } + + for k, v := range files { + fi, _ := os.Lstat(v) + length += len(k) + int(fi.Size()) + } + resp, err = fcgi.PostFile(fcgiParams, values, files) + } + + if err != nil { + log.Println("err:", err) + return + } + + defer resp.Body.Close() + content, _ = ioutil.ReadAll(resp.Body) + + log.Println("c: send data length ≈", length, string(content)) + fcgi.Close() + time.Sleep(1 * time.Second) + + if bytes.Contains(content, []byte("FAILED")) { + globalt.Error("Server return failed message") + } + + return +} + +func generateRandFile(size int) (p string, m string) { + + p = filepath.Join(os.TempDir(), "fcgict"+strconv.Itoa(rand.Int())) + + // open output file + fo, err := os.Create(p) + if err != nil { + panic(err) + } + // close fo on exit and check for its returned error + defer func() { + if err := fo.Close(); err != nil { + panic(err) + } + }() + + h := md5.New() + for i := 0; i < size/16; i++ { + buf := make([]byte, 16) + binary.PutVarint(buf, rand.Int63()) + if _, err := fo.Write(buf); err != nil { + log.Printf("[ERROR] failed to write buffer: %v\n", err) + } + if _, err := h.Write(buf); err != nil { + log.Printf("[ERROR] failed to write buffer: %v\n", err) + } + } + m = fmt.Sprintf("%x", h.Sum(nil)) + return +} + +func DisabledTest(t *testing.T) { + // TODO: test chunked reader + globalt = t + + rand.Seed(time.Now().UTC().UnixNano()) + + // server + go func() { + listener, err := net.Listen("tcp", ipPort) + if err != nil { + log.Println("listener creation failed: ", err) + } + + srv := new(FastCGIServer) + if err := fcgi.Serve(listener, srv); err != nil { + log.Print("[ERROR] failed to start server: ", err) + } + }() + + time.Sleep(1 * time.Second) + + // init + fcgiParams := make(map[string]string) + fcgiParams["REQUEST_METHOD"] = "GET" + fcgiParams["SERVER_PROTOCOL"] = "HTTP/1.1" + //fcgi_params["GATEWAY_INTERFACE"] = "CGI/1.1" + fcgiParams["SCRIPT_FILENAME"] = scriptFile + + // simple GET + log.Println("test:", "get") + sendFcgi(0, fcgiParams, nil, nil, nil) + + // simple post data + log.Println("test:", "post") + sendFcgi(0, fcgiParams, []byte("c4ca4238a0b923820dcc509a6f75849b=1&7b8b965ad4bca0e41ab51de7b31363a1=n"), nil, nil) + + log.Println("test:", "post data (more than 60KB)") + data := "" + for i := 0x00; i < 0xff; i++ { + v0 := strings.Repeat(string(i), 256) + h := md5.New() + _, _ = io.WriteString(h, v0) + k0 := fmt.Sprintf("%x", h.Sum(nil)) + data += k0 + "=" + url.QueryEscape(v0) + "&" + } + sendFcgi(0, fcgiParams, []byte(data), nil, nil) + + log.Println("test:", "post form (use url.Values)") + p0 := make(map[string]string, 1) + p0["c4ca4238a0b923820dcc509a6f75849b"] = "1" + p0["7b8b965ad4bca0e41ab51de7b31363a1"] = "n" + sendFcgi(1, fcgiParams, nil, p0, nil) + + log.Println("test:", "post forms (256 keys, more than 1MB)") + p1 := make(map[string]string, 1) + for i := 0x00; i < 0xff; i++ { + v0 := strings.Repeat(string(i), 4096) + h := md5.New() + _, _ = io.WriteString(h, v0) + k0 := fmt.Sprintf("%x", h.Sum(nil)) + p1[k0] = v0 + } + sendFcgi(1, fcgiParams, nil, p1, nil) + + log.Println("test:", "post file (1 file, 500KB)) ") + f0 := make(map[string]string, 1) + path0, m0 := generateRandFile(500000) + f0[m0] = path0 + sendFcgi(1, fcgiParams, nil, p1, f0) + + log.Println("test:", "post multiple files (2 files, 5M each) and forms (256 keys, more than 1MB data") + path1, m1 := generateRandFile(5000000) + f0[m1] = path1 + sendFcgi(1, fcgiParams, nil, p1, f0) + + log.Println("test:", "post only files (2 files, 5M each)") + sendFcgi(1, fcgiParams, nil, nil, f0) + + log.Println("test:", "post only 1 file") + delete(f0, "m0") + sendFcgi(1, fcgiParams, nil, nil, f0) + + if err := os.Remove(path0); err != nil { + log.Println("[ERROR] failed to remove path: ", err) + } + if err := os.Remove(path1); err != nil { + log.Println("[ERROR] failed to remove path: ", err) + } +} diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go new file mode 100644 index 0000000..91039c9 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go @@ -0,0 +1,301 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fastcgi + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "net/url" + "path" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/caddyserver/caddy/v2/modules/caddyhttp" + "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy" + "github.com/caddyserver/caddy/v2/modules/caddytls" + + "github.com/caddyserver/caddy/v2" +) + +func init() { + caddy.RegisterModule(Transport{}) +} + +// Transport facilitates FastCGI communication. +type Transport struct { + // TODO: Populate these + softwareName string + softwareVersion string + serverName string + serverPort string + + // Use this directory as the fastcgi root directory. Defaults to the root + // directory of the parent virtual host. + Root string `json:"root,omitempty"` + + // The path in the URL will be split into two, with the first piece ending + // with the value of SplitPath. The first piece will be assumed as the + // actual resource (CGI script) name, and the second piece will be set to + // PATH_INFO for the CGI script to use. + SplitPath string `json:"split_path,omitempty"` + + // Environment variables (TODO: make a map of string to value...?) + EnvVars [][2]string `json:"env,omitempty"` + + // The duration used to set a deadline when connecting to an upstream. + DialTimeout caddy.Duration `json:"dial_timeout,omitempty"` + + // The duration used to set a deadline when reading from the FastCGI server. + ReadTimeout caddy.Duration `json:"read_timeout,omitempty"` + + // The duration used to set a deadline when sending to the FastCGI server. + WriteTimeout caddy.Duration `json:"write_timeout,omitempty"` +} + +// CaddyModule returns the Caddy module information. +func (Transport) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.transport.fastcgi", + New: func() caddy.Module { return new(Transport) }, + } +} + +// Provision sets up t. +func (t *Transport) Provision(_ caddy.Context) error { + if t.Root == "" { + t.Root = "{http.vars.root}" + } + return nil +} + +// RoundTrip implements http.RoundTripper. +func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) { + env, err := t.buildEnv(r) + if err != nil { + return nil, fmt.Errorf("building environment: %v", err) + } + + // TODO: doesn't dialer have a Timeout field? + ctx := r.Context() + if t.DialTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(t.DialTimeout)) + defer cancel() + } + + // extract dial information from request (this + // should embedded by the reverse proxy) + network, address := "tcp", r.URL.Host + if dialInfoVal := ctx.Value(reverseproxy.DialInfoCtxKey); dialInfoVal != nil { + dialInfo := dialInfoVal.(reverseproxy.DialInfo) + network = dialInfo.Network + address = dialInfo.Address + } + + fcgiBackend, err := DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("dialing backend: %v", err) + } + // fcgiBackend gets closed when response body is closed (see clientCloser) + + // read/write timeouts + if err := fcgiBackend.SetReadTimeout(time.Duration(t.ReadTimeout)); err != nil { + return nil, fmt.Errorf("setting read timeout: %v", err) + } + if err := fcgiBackend.SetWriteTimeout(time.Duration(t.WriteTimeout)); err != nil { + return nil, fmt.Errorf("setting write timeout: %v", err) + } + + contentLength := r.ContentLength + if contentLength == 0 { + contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64) + } + + var resp *http.Response + switch r.Method { + case http.MethodHead: + resp, err = fcgiBackend.Head(env) + case http.MethodGet: + resp, err = fcgiBackend.Get(env, r.Body, contentLength) + case http.MethodOptions: + resp, err = fcgiBackend.Options(env) + default: + resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength) + } + + return resp, err +} + +// buildEnv returns a set of CGI environment variables for the request. +func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { + repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) + + var env map[string]string + + // Separate remote IP and port; more lenient than net.SplitHostPort + var ip, port string + if idx := strings.LastIndex(r.RemoteAddr, ":"); idx > -1 { + ip = r.RemoteAddr[:idx] + port = r.RemoteAddr[idx+1:] + } else { + ip = r.RemoteAddr + } + + // Remove [] from IPv6 addresses + ip = strings.Replace(ip, "[", "", 1) + ip = strings.Replace(ip, "]", "", 1) + + root := repl.ReplaceAll(t.Root, ".") + fpath := r.URL.Path + + // Split path in preparation for env variables. + // Previous canSplit checks ensure this can never be -1. + // TODO: I haven't brought over canSplit; make sure this doesn't break + splitPos := t.splitPos(fpath) + + // Request has the extension; path was split successfully + docURI := fpath[:splitPos+len(t.SplitPath)] + pathInfo := fpath[splitPos+len(t.SplitPath):] + scriptName := fpath + + // Strip PATH_INFO from SCRIPT_NAME + scriptName = strings.TrimSuffix(scriptName, pathInfo) + + // SCRIPT_FILENAME is the absolute path of SCRIPT_NAME + scriptFilename := filepath.Join(root, scriptName) + + // Add vhost path prefix to scriptName. Otherwise, some PHP software will + // have difficulty discovering its URL. + pathPrefix, _ := r.Context().Value(caddy.CtxKey("path_prefix")).(string) + scriptName = path.Join(pathPrefix, scriptName) + + // Get the request URL from context. The context stores the original URL in case + // it was changed by a middleware such as rewrite. By default, we pass the + // original URI in as the value of REQUEST_URI (the user can overwrite this + // if desired). Most PHP apps seem to want the original URI. Besides, this is + // how nginx defaults: http://stackoverflow.com/a/12485156/1048862 + reqURL, ok := r.Context().Value(caddyhttp.OriginalURLCtxKey).(url.URL) + if !ok { + // some requests, like active health checks, don't add this to + // the request context, so we can just use the current URL + reqURL = *r.URL + } + + requestScheme := "http" + if r.TLS != nil { + requestScheme = "https" + } + + // Some variables are unused but cleared explicitly to prevent + // the parent environment from interfering. + env = map[string]string{ + // Variables defined in CGI 1.1 spec + "AUTH_TYPE": "", // Not used + "CONTENT_LENGTH": r.Header.Get("Content-Length"), + "CONTENT_TYPE": r.Header.Get("Content-Type"), + "GATEWAY_INTERFACE": "CGI/1.1", + "PATH_INFO": pathInfo, + "QUERY_STRING": r.URL.RawQuery, + "REMOTE_ADDR": ip, + "REMOTE_HOST": ip, // For speed, remote host lookups disabled + "REMOTE_PORT": port, + "REMOTE_IDENT": "", // Not used + "REMOTE_USER": "", // TODO: once there are authentication handlers, populate this + "REQUEST_METHOD": r.Method, + "REQUEST_SCHEME": requestScheme, + "SERVER_NAME": t.serverName, + "SERVER_PORT": t.serverPort, + "SERVER_PROTOCOL": r.Proto, + "SERVER_SOFTWARE": t.softwareName + "/" + t.softwareVersion, + + // Other variables + "DOCUMENT_ROOT": root, + "DOCUMENT_URI": docURI, + "HTTP_HOST": r.Host, // added here, since not always part of headers + "REQUEST_URI": reqURL.RequestURI(), + "SCRIPT_FILENAME": scriptFilename, + "SCRIPT_NAME": scriptName, + } + + // compliance with the CGI specification requires that + // PATH_TRANSLATED should only exist if PATH_INFO is defined. + // Info: https://www.ietf.org/rfc/rfc3875 Page 14 + if env["PATH_INFO"] != "" { + env["PATH_TRANSLATED"] = filepath.Join(root, pathInfo) // Info: http://www.oreilly.com/openbook/cgi/ch02_04.html + } + + // Some web apps rely on knowing HTTPS or not + if r.TLS != nil { + env["HTTPS"] = "on" + // and pass the protocol details in a manner compatible with apache's mod_ssl + // (which is why these have a SSL_ prefix and not TLS_). + v, ok := tlsProtocolStrings[r.TLS.Version] + if ok { + env["SSL_PROTOCOL"] = v + } + // and pass the cipher suite in a manner compatible with apache's mod_ssl + for k, v := range caddytls.SupportedCipherSuites { + if v == r.TLS.CipherSuite { + env["SSL_CIPHER"] = k + break + } + } + } + + // Add env variables from config (with support for placeholders in values) + for _, envVar := range t.EnvVars { + env[envVar[0]] = repl.ReplaceAll(envVar[1], "") + } + + // Add all HTTP headers to env variables + for field, val := range r.Header { + header := strings.ToUpper(field) + header = headerNameReplacer.Replace(header) + env["HTTP_"+header] = strings.Join(val, ", ") + } + return env, nil +} + +// splitPos returns the index where path should +// be split based on t.SplitPath. +func (t Transport) splitPos(path string) int { + // TODO: + // if httpserver.CaseSensitivePath { + // return strings.Index(path, r.SplitPath) + // } + return strings.Index(strings.ToLower(path), strings.ToLower(t.SplitPath)) +} + +// TODO: +// Map of supported protocols to Apache ssl_mod format +// Note that these are slightly different from SupportedProtocols in caddytls/config.go +var tlsProtocolStrings = map[uint16]string{ + tls.VersionTLS10: "TLSv1", + tls.VersionTLS11: "TLSv1.1", + tls.VersionTLS12: "TLSv1.2", + tls.VersionTLS13: "TLSv1.3", +} + +var headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_") + +// Interface guards +var ( + _ caddy.Provisioner = (*Transport)(nil) + _ http.RoundTripper = (*Transport)(nil) +) diff --git a/modules/caddyhttp/reverseproxy/healthchecker.go b/modules/caddyhttp/reverseproxy/healthchecker.go deleted file mode 100755 index c557d3f..0000000 --- a/modules/caddyhttp/reverseproxy/healthchecker.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2015 Matthew Holt and The Caddy Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package reverseproxy - -import ( - "net/http" - "time" -) - -// Upstream represents the interface that must be satisfied to use the healthchecker. -type Upstream interface { - SetHealthiness(bool) -} - -// HealthChecker represents a worker that periodically evaluates if proxy upstream host is healthy. -type HealthChecker struct { - upstream Upstream - Ticker *time.Ticker - HTTPClient *http.Client - StopChan chan bool -} - -// ScheduleChecks periodically runs health checks against an upstream host. -func (h *HealthChecker) ScheduleChecks(url string) { - // check if a host is healthy on start vs waiting for timer - h.upstream.SetHealthiness(h.IsHealthy(url)) - stop := make(chan bool) - h.StopChan = stop - - go func() { - for { - select { - case <-h.Ticker.C: - h.upstream.SetHealthiness(h.IsHealthy(url)) - case <-stop: - return - } - } - }() -} - -// Stop stops the healthchecker from makeing further requests. -func (h *HealthChecker) Stop() { - h.Ticker.Stop() - close(h.StopChan) -} - -// IsHealthy attempts to check if a upstream host is healthy. -func (h *HealthChecker) IsHealthy(url string) bool { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return false - } - - resp, err := h.HTTPClient.Do(req) - if err != nil { - return false - } - - if resp.StatusCode < 200 || resp.StatusCode >= 400 { - return false - } - - return true -} - -// NewHealthCheckWorker returns a new instance of a HealthChecker. -func NewHealthCheckWorker(u Upstream, interval time.Duration, client *http.Client) *HealthChecker { - return &HealthChecker{ - upstream: u, - Ticker: time.NewTicker(interval), - HTTPClient: client, - } -} diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go new file mode 100644 index 0000000..abe0f9c --- /dev/null +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -0,0 +1,270 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "net/url" + "regexp" + "strconv" + "time" + + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" +) + +// HealthChecks holds configuration related to health checking. +type HealthChecks struct { + Active *ActiveHealthChecks `json:"active,omitempty"` + Passive *PassiveHealthChecks `json:"passive,omitempty"` +} + +// ActiveHealthChecks holds configuration related to active +// health checks (that is, health checks which occur in a +// background goroutine independently). +type ActiveHealthChecks struct { + Path string `json:"path,omitempty"` + Port int `json:"port,omitempty"` + Interval caddy.Duration `json:"interval,omitempty"` + Timeout caddy.Duration `json:"timeout,omitempty"` + MaxSize int64 `json:"max_size,omitempty"` + ExpectStatus int `json:"expect_status,omitempty"` + ExpectBody string `json:"expect_body,omitempty"` + + stopChan chan struct{} + httpClient *http.Client + bodyRegexp *regexp.Regexp +} + +// PassiveHealthChecks holds configuration related to passive +// health checks (that is, health checks which occur during +// the normal flow of request proxying). +type PassiveHealthChecks struct { + MaxFails int `json:"max_fails,omitempty"` + FailDuration caddy.Duration `json:"fail_duration,omitempty"` + UnhealthyRequestCount int `json:"unhealthy_request_count,omitempty"` + UnhealthyStatus []int `json:"unhealthy_status,omitempty"` + UnhealthyLatency caddy.Duration `json:"unhealthy_latency,omitempty"` +} + +// CircuitBreaker is a type that can act as an early-warning +// system for the health checker when backends are getting +// overloaded. +type CircuitBreaker interface { + OK() bool + RecordMetric(statusCode int, latency time.Duration) +} + +// activeHealthChecker runs active health checks on a +// regular basis and blocks until +// h.HealthChecks.Active.stopChan is closed. +func (h *Handler) activeHealthChecker() { + ticker := time.NewTicker(time.Duration(h.HealthChecks.Active.Interval)) + h.doActiveHealthChecksForAllHosts() + for { + select { + case <-ticker.C: + h.doActiveHealthChecksForAllHosts() + case <-h.HealthChecks.Active.stopChan: + ticker.Stop() + return + } + } +} + +// doActiveHealthChecksForAllHosts immediately performs a +// health checks for all hosts in the global repository. +func (h *Handler) doActiveHealthChecksForAllHosts() { + hosts.Range(func(key, value interface{}) bool { + networkAddr := key.(string) + host := value.(Host) + + go func(networkAddr string, host Host) { + network, addrs, err := caddy.ParseNetworkAddress(networkAddr) + if err != nil { + log.Printf("[ERROR] reverse_proxy: active health check for host %s: bad network address: %v", networkAddr, err) + return + } + if len(addrs) != 1 { + log.Printf("[ERROR] reverse_proxy: active health check for host %s: multiple addresses (upstream must map to only one address)", networkAddr) + return + } + hostAddr := addrs[0] + if network == "unix" || network == "unixgram" || network == "unixpacket" { + // this will be used as the Host portion of a http.Request URL, and + // paths to socket files would produce an error when creating URL, + // so use a fake Host value instead + hostAddr = network + } + err = h.doActiveHealthCheck(DialInfo{network, addrs[0]}, hostAddr, host) + if err != nil { + log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", networkAddr, err) + } + }(networkAddr, host) + + // continue to iterate all hosts + return true + }) +} + +// doActiveHealthCheck performs a health check to host which +// can be reached at address hostAddr. The actual address for +// the request will be built according to active health checker +// config. The health status of the host will be updated +// according to whether it passes the health check. An error is +// returned only if the health check fails to occur or if marking +// the host's health status fails. +func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host Host) error { + // create the URL for the request that acts as a health check + scheme := "http" + if ht, ok := h.Transport.(*http.Transport); ok && ht.TLSClientConfig != nil { + // this is kind of a hacky way to know if we should use HTTPS, but whatever + scheme = "https" + } + u := &url.URL{ + Scheme: scheme, + Host: hostAddr, + Path: h.HealthChecks.Active.Path, + } + + // adjust the port, if configured to be different + if h.HealthChecks.Active.Port != 0 { + portStr := strconv.Itoa(h.HealthChecks.Active.Port) + host, _, err := net.SplitHostPort(hostAddr) + if err != nil { + host = hostAddr + } + u.Host = net.JoinHostPort(host, portStr) + } + + // attach dialing information to this request + ctx := context.Background() + ctx = context.WithValue(ctx, caddy.ReplacerCtxKey, caddy.NewReplacer()) + ctx = context.WithValue(ctx, DialInfoCtxKey, dialInfo) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return fmt.Errorf("making request: %v", err) + } + + // do the request, being careful to tame the response body + resp, err := h.HealthChecks.Active.httpClient.Do(req) + if err != nil { + log.Printf("[INFO] reverse_proxy: active health check: %s is down (HTTP request failed: %v)", hostAddr, err) + _, err2 := host.SetHealthy(false) + if err2 != nil { + return fmt.Errorf("marking unhealthy: %v", err2) + } + return nil + } + var body io.Reader = resp.Body + if h.HealthChecks.Active.MaxSize > 0 { + body = io.LimitReader(body, h.HealthChecks.Active.MaxSize) + } + defer func() { + // drain any remaining body so connection could be re-used + io.Copy(ioutil.Discard, body) + resp.Body.Close() + }() + + // if status code is outside criteria, mark down + if h.HealthChecks.Active.ExpectStatus > 0 { + if !caddyhttp.StatusCodeMatches(resp.StatusCode, h.HealthChecks.Active.ExpectStatus) { + log.Printf("[INFO] reverse_proxy: active health check: %s is down (status code %d unexpected)", hostAddr, resp.StatusCode) + _, err := host.SetHealthy(false) + if err != nil { + return fmt.Errorf("marking unhealthy: %v", err) + } + return nil + } + } else if resp.StatusCode < 200 || resp.StatusCode >= 400 { + log.Printf("[INFO] reverse_proxy: active health check: %s is down (status code %d out of tolerances)", hostAddr, resp.StatusCode) + _, err := host.SetHealthy(false) + if err != nil { + return fmt.Errorf("marking unhealthy: %v", err) + } + return nil + } + + // if body does not match regex, mark down + if h.HealthChecks.Active.bodyRegexp != nil { + bodyBytes, err := ioutil.ReadAll(body) + if err != nil { + log.Printf("[INFO] reverse_proxy: active health check: %s is down (failed to read response body)", hostAddr) + _, err := host.SetHealthy(false) + if err != nil { + return fmt.Errorf("marking unhealthy: %v", err) + } + return nil + } + if !h.HealthChecks.Active.bodyRegexp.Match(bodyBytes) { + log.Printf("[INFO] reverse_proxy: active health check: %s is down (response body failed expectations)", hostAddr) + _, err := host.SetHealthy(false) + if err != nil { + return fmt.Errorf("marking unhealthy: %v", err) + } + return nil + } + } + + // passed health check parameters, so mark as healthy + swapped, err := host.SetHealthy(true) + if swapped { + log.Printf("[INFO] reverse_proxy: active health check: %s is back up", hostAddr) + } + if err != nil { + return fmt.Errorf("marking healthy: %v", err) + } + + return nil +} + +// countFailure is used with passive health checks. It +// remembers 1 failure for upstream for the configured +// duration. If passive health checks are disabled or +// failure expiry is 0, this is a no-op. +func (h *Handler) countFailure(upstream *Upstream) { + // only count failures if passive health checking is enabled + // and if failures are configured have a non-zero expiry + if h.HealthChecks == nil || h.HealthChecks.Passive == nil { + return + } + failDuration := time.Duration(h.HealthChecks.Passive.FailDuration) + if failDuration == 0 { + return + } + + // count failure immediately + err := upstream.Host.CountFail(1) + if err != nil { + log.Printf("[ERROR] proxy: upstream %s: counting failure: %v", + upstream.dialInfo, err) + } + + // forget it later + go func(host Host, failDuration time.Duration) { + time.Sleep(failDuration) + err := host.CountFail(-1) + if err != nil { + log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v", + upstream.dialInfo, err) + } + }(upstream.Host, failDuration) +} diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go new file mode 100644 index 0000000..1c0fae3 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -0,0 +1,193 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "fmt" + "sync/atomic" + + "github.com/caddyserver/caddy/v2" +) + +// Host represents a remote host which can be proxied to. +// Its methods must be safe for concurrent use. +type Host interface { + // NumRequests returns the numnber of requests + // currently in process with the host. + NumRequests() int + + // Fails returns the count of recent failures. + Fails() int + + // Unhealthy returns true if the backend is unhealthy. + Unhealthy() bool + + // CountRequest atomically counts the given number of + // requests as currently in process with the host. The + // count should not go below 0. + CountRequest(int) error + + // CountFail atomically counts the given number of + // failures with the host. The count should not go + // below 0. + CountFail(int) error + + // SetHealthy atomically marks the host as either + // healthy (true) or unhealthy (false). If the given + // status is the same, this should be a no-op and + // return false. It returns true if the status was + // changed; i.e. if it is now different from before. + SetHealthy(bool) (bool, error) +} + +// UpstreamPool is a collection of upstreams. +type UpstreamPool []*Upstream + +// Upstream bridges this proxy's configuration to the +// state of the backend host it is correlated with. +type Upstream struct { + Host `json:"-"` + + Dial string `json:"dial,omitempty"` + MaxRequests int `json:"max_requests,omitempty"` + + // TODO: This could be really useful, to bind requests + // with certain properties to specific backends + // HeaderAffinity string + // IPAffinity string + + healthCheckPolicy *PassiveHealthChecks + cb CircuitBreaker + dialInfo DialInfo +} + +// Available returns true if the remote host +// is available to receive requests. This is +// the method that should be used by selection +// policies, etc. to determine if a backend +// should be able to be sent a request. +func (u *Upstream) Available() bool { + return u.Healthy() && !u.Full() +} + +// Healthy returns true if the remote host +// is currently known to be healthy or "up". +// It consults the circuit breaker, if any. +func (u *Upstream) Healthy() bool { + healthy := !u.Host.Unhealthy() + if healthy && u.healthCheckPolicy != nil { + healthy = u.Host.Fails() < u.healthCheckPolicy.MaxFails + } + if healthy && u.cb != nil { + healthy = u.cb.OK() + } + return healthy +} + +// Full returns true if the remote host +// cannot receive more requests at this time. +func (u *Upstream) Full() bool { + return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests +} + +// upstreamHost is the basic, in-memory representation +// of the state of a remote host. It implements the +// Host interface. +type upstreamHost struct { + numRequests int64 // must be first field to be 64-bit aligned on 32-bit systems (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG) + fails int64 + unhealthy int32 +} + +// NumRequests returns the number of active requests to the upstream. +func (uh *upstreamHost) NumRequests() int { + return int(atomic.LoadInt64(&uh.numRequests)) +} + +// Fails returns the number of recent failures with the upstream. +func (uh *upstreamHost) Fails() int { + return int(atomic.LoadInt64(&uh.fails)) +} + +// Unhealthy returns whether the upstream is healthy. +func (uh *upstreamHost) Unhealthy() bool { + return atomic.LoadInt32(&uh.unhealthy) == 1 +} + +// CountRequest mutates the active request count by +// delta. It returns an error if the adjustment fails. +func (uh *upstreamHost) CountRequest(delta int) error { + result := atomic.AddInt64(&uh.numRequests, int64(delta)) + if result < 0 { + return fmt.Errorf("count below 0: %d", result) + } + return nil +} + +// CountFail mutates the recent failures count by +// delta. It returns an error if the adjustment fails. +func (uh *upstreamHost) CountFail(delta int) error { + result := atomic.AddInt64(&uh.fails, int64(delta)) + if result < 0 { + return fmt.Errorf("count below 0: %d", result) + } + return nil +} + +// SetHealthy sets the upstream has healthy or unhealthy +// and returns true if the value was different from before, +// or an error if the adjustment failed. +func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) { + var unhealthy, compare int32 = 1, 0 + if healthy { + unhealthy, compare = 0, 1 + } + swapped := atomic.CompareAndSwapInt32(&uh.unhealthy, compare, unhealthy) + return swapped, nil +} + +// DialInfo contains information needed to dial a +// connection to an upstream host. This information +// may be different than that which is represented +// in a URL (for example, unix sockets don't have +// a host that can be represented in a URL, but +// they certainly have a network name and address). +type DialInfo struct { + // The network to use. This should be one of the + // values that is accepted by net.Dial: + // https://golang.org/pkg/net/#Dial + Network string + + // The address to dial. Follows the same + // semantics and rules as net.Dial. + Address string +} + +// String returns the Caddy network address form +// by joining the network and address with a +// forward slash. +func (di DialInfo) String() string { + return di.Network + "/" + di.Address +} + +// DialInfoCtxKey is used to store a DialInfo +// in a context.Context. +const DialInfoCtxKey = caddy.CtxKey("dial_info") + +// hosts is the global repository for hosts that are +// currently in use by active configuration(s). This +// allows the state of remote hosts to be preserved +// through config reloads. +var hosts = caddy.NewUsagePool() diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go new file mode 100644 index 0000000..c135ac8 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -0,0 +1,208 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "fmt" + "net" + "net/http" + "reflect" + "time" + + "github.com/caddyserver/caddy/v2" +) + +func init() { + caddy.RegisterModule(HTTPTransport{}) +} + +// HTTPTransport is essentially a configuration wrapper for http.Transport. +// It defines a JSON structure useful when configuring the HTTP transport +// for Caddy's reverse proxy. +type HTTPTransport struct { + // TODO: It's possible that other transports (like fastcgi) might be + // able to borrow/use at least some of these config fields; if so, + // move them into a type called CommonTransport and embed it + TLS *TLSConfig `json:"tls,omitempty"` + KeepAlive *KeepAlive `json:"keep_alive,omitempty"` + Compression *bool `json:"compression,omitempty"` + MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // TODO: NOTE: we use our health check stuff to enforce max REQUESTS per host, but this is connections + DialTimeout caddy.Duration `json:"dial_timeout,omitempty"` + FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"` + ResponseHeaderTimeout caddy.Duration `json:"response_header_timeout,omitempty"` + ExpectContinueTimeout caddy.Duration `json:"expect_continue_timeout,omitempty"` + MaxResponseHeaderSize int64 `json:"max_response_header_size,omitempty"` + WriteBufferSize int `json:"write_buffer_size,omitempty"` + ReadBufferSize int `json:"read_buffer_size,omitempty"` + + RoundTripper http.RoundTripper `json:"-"` +} + +// CaddyModule returns the Caddy module information. +func (HTTPTransport) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.transport.http", + New: func() caddy.Module { return new(HTTPTransport) }, + } +} + +// Provision sets up h.RoundTripper with a http.Transport +// that is ready to use. +func (h *HTTPTransport) Provision(_ caddy.Context) error { + dialer := &net.Dialer{ + Timeout: time.Duration(h.DialTimeout), + FallbackDelay: time.Duration(h.FallbackDelay), + // TODO: Resolver + } + + rt := &http.Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + // the proper dialing information should be embedded into the request's context + if dialInfoVal := ctx.Value(DialInfoCtxKey); dialInfoVal != nil { + dialInfo := dialInfoVal.(DialInfo) + network = dialInfo.Network + address = dialInfo.Address + } + return dialer.DialContext(ctx, network, address) + }, + MaxConnsPerHost: h.MaxConnsPerHost, + ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout), + ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout), + MaxResponseHeaderBytes: h.MaxResponseHeaderSize, + WriteBufferSize: h.WriteBufferSize, + ReadBufferSize: h.ReadBufferSize, + } + + if h.TLS != nil { + rt.TLSHandshakeTimeout = time.Duration(h.TLS.HandshakeTimeout) + + var err error + rt.TLSClientConfig, err = h.TLS.MakeTLSClientConfig() + if err != nil { + return fmt.Errorf("making TLS client config: %v", err) + } + } + + if h.KeepAlive != nil { + dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval) + if enabled := h.KeepAlive.Enabled; enabled != nil { + rt.DisableKeepAlives = !*enabled + } + rt.MaxIdleConns = h.KeepAlive.MaxIdleConns + rt.MaxIdleConnsPerHost = h.KeepAlive.MaxIdleConnsPerHost + rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout) + } + + if h.Compression != nil { + rt.DisableCompression = !*h.Compression + } + + h.RoundTripper = rt + + return nil +} + +// RoundTrip implements http.RoundTripper with h.RoundTripper. +func (h HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return h.RoundTripper.RoundTrip(req) +} + +// TLSConfig holds configuration related to the +// TLS configuration for the transport/client. +type TLSConfig struct { + RootCAPool []string `json:"root_ca_pool,omitempty"` + // TODO: Should the client cert+key config use caddytls.CertificateLoader modules? + ClientCertificateFile string `json:"client_certificate_file,omitempty"` + ClientCertificateKeyFile string `json:"client_certificate_key_file,omitempty"` + InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` + HandshakeTimeout caddy.Duration `json:"handshake_timeout,omitempty"` +} + +// MakeTLSClientConfig returns a tls.Config usable by a client to a backend. +// If there is no custom TLS configuration, a nil config may be returned. +func (t TLSConfig) MakeTLSClientConfig() (*tls.Config, error) { + cfg := new(tls.Config) + + // client auth + if t.ClientCertificateFile != "" && t.ClientCertificateKeyFile == "" { + return nil, fmt.Errorf("client_certificate_file specified without client_certificate_key_file") + } + if t.ClientCertificateFile == "" && t.ClientCertificateKeyFile != "" { + return nil, fmt.Errorf("client_certificate_key_file specified without client_certificate_file") + } + if t.ClientCertificateFile != "" && t.ClientCertificateKeyFile != "" { + cert, err := tls.LoadX509KeyPair(t.ClientCertificateFile, t.ClientCertificateKeyFile) + if err != nil { + return nil, fmt.Errorf("loading client certificate key pair: %v", err) + } + cfg.Certificates = []tls.Certificate{cert} + } + + // trusted root CAs + if len(t.RootCAPool) > 0 { + rootPool := x509.NewCertPool() + for _, encodedCACert := range t.RootCAPool { + caCert, err := decodeBase64DERCert(encodedCACert) + if err != nil { + return nil, fmt.Errorf("parsing CA certificate: %v", err) + } + rootPool.AddCert(caCert) + } + cfg.RootCAs = rootPool + } + + // throw all security out the window + cfg.InsecureSkipVerify = t.InsecureSkipVerify + + // only return a config if it's not empty + if reflect.DeepEqual(cfg, new(tls.Config)) { + return nil, nil + } + + cfg.NextProtos = []string{"h2", "http/1.1"} // TODO: ensure that this actually enables HTTP/2 + + return cfg, nil +} + +// decodeBase64DERCert base64-decodes, then DER-decodes, certStr. +func decodeBase64DERCert(certStr string) (*x509.Certificate, error) { + // decode base64 + derBytes, err := base64.StdEncoding.DecodeString(certStr) + if err != nil { + return nil, err + } + + // parse the DER-encoded certificate + return x509.ParseCertificate(derBytes) +} + +// KeepAlive holds configuration pertaining to HTTP Keep-Alive. +type KeepAlive struct { + Enabled *bool `json:"enabled,omitempty"` + ProbeInterval caddy.Duration `json:"probe_interval,omitempty"` + MaxIdleConns int `json:"max_idle_conns,omitempty"` + MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"` + IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle +} + +// Interface guards +var ( + _ caddy.Provisioner = (*HTTPTransport)(nil) + _ http.RoundTripper = (*HTTPTransport)(nil) +) diff --git a/modules/caddyhttp/reverseproxy/module.go b/modules/caddyhttp/reverseproxy/module.go deleted file mode 100755 index 21aca1d..0000000 --- a/modules/caddyhttp/reverseproxy/module.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2015 Matthew Holt and The Caddy Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package reverseproxy - -import ( - "github.com/caddyserver/caddy/v2" - "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" - "github.com/caddyserver/caddy/v2/modules/caddyhttp" -) - -func init() { - caddy.RegisterModule(new(LoadBalanced)) - httpcaddyfile.RegisterHandlerDirective("reverse_proxy", parseCaddyfile) // TODO: "proxy"? -} - -// CaddyModule returns the Caddy module information. -func (*LoadBalanced) CaddyModule() caddy.ModuleInfo { - return caddy.ModuleInfo{ - Name: "http.handlers.reverse_proxy", - New: func() caddy.Module { return new(LoadBalanced) }, - } -} - -// parseCaddyfile sets up the handler from Caddyfile tokens. Syntax: -// -// proxy [<matcher>] <to> -// -// TODO: This needs to be finished. It definitely needs to be able to open a block... -func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { - lb := new(LoadBalanced) - for h.Next() { - allTo := h.RemainingArgs() - if len(allTo) == 0 { - return nil, h.ArgErr() - } - for _, to := range allTo { - lb.Upstreams = append(lb.Upstreams, &UpstreamConfig{Host: to}) - } - } - return lb, nil -} diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 68393de..5a37613 100755..100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -16,227 +16,304 @@ package reverseproxy import ( "context" + "encoding/json" "fmt" - "io" - "log" "net" "net/http" - "net/url" + "regexp" "strings" - "sync" "time" + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" "golang.org/x/net/http/httpguts" ) -// ReverseProxy is an HTTP Handler that takes an incoming request and -// sends it to another server, proxying the response back to the -// client. -type ReverseProxy struct { - // Director must be a function which modifies - // the request into a new request to be sent - // using Transport. Its response is then copied - // back to the original client unmodified. - // Director must not access the provided Request - // after returning. - Director func(*http.Request) - - // The transport used to perform proxy requests. - // If nil, http.DefaultTransport is used. - Transport http.RoundTripper - - // FlushInterval specifies the flush interval - // to flush to the client while copying the - // response body. - // If zero, no periodic flushing is done. - // A negative value means to flush immediately - // after each write to the client. - // The FlushInterval is ignored when ReverseProxy - // recognizes a response as a streaming response; - // for such responses, writes are flushed to the client - // immediately. - FlushInterval time.Duration - - // ErrorLog specifies an optional logger for errors - // that occur when attempting to proxy the request. - // If nil, logging goes to os.Stderr via the log package's - // standard logger. - ErrorLog *log.Logger - - // BufferPool optionally specifies a buffer pool to - // get byte slices for use by io.CopyBuffer when - // copying HTTP response bodies. - BufferPool BufferPool - - // ModifyResponse is an optional function that modifies the - // Response from the backend. It is called if the backend - // returns a response at all, with any HTTP status code. - // If the backend is unreachable, the optional ErrorHandler is - // called without any call to ModifyResponse. - // - // If ModifyResponse returns an error, ErrorHandler is called - // with its error value. If ErrorHandler is nil, its default - // implementation is used. - ModifyResponse func(*http.Response) error - - // ErrorHandler is an optional function that handles errors - // reaching the backend or errors from ModifyResponse. - // - // If nil, the default is to log the provided error and return - // a 502 Status Bad Gateway response. - ErrorHandler func(http.ResponseWriter, *http.Request, error) +func init() { + caddy.RegisterModule(Handler{}) } -// A BufferPool is an interface for getting and returning temporary -// byte slices for use by io.CopyBuffer. -type BufferPool interface { - Get() []byte - Put([]byte) +// Handler implements a highly configurable and production-ready reverse proxy. +type Handler struct { + TransportRaw json.RawMessage `json:"transport,omitempty"` + CBRaw json.RawMessage `json:"circuit_breaker,omitempty"` + LoadBalancing *LoadBalancing `json:"load_balancing,omitempty"` + HealthChecks *HealthChecks `json:"health_checks,omitempty"` + Upstreams UpstreamPool `json:"upstreams,omitempty"` + FlushInterval caddy.Duration `json:"flush_interval,omitempty"` + + Transport http.RoundTripper `json:"-"` + CB CircuitBreaker `json:"-"` } -func singleJoiningSlash(a, b string) string { - aslash := strings.HasSuffix(a, "/") - bslash := strings.HasPrefix(b, "/") - switch { - case aslash && bslash: - return a + b[1:] - case !aslash && !bslash: - return a + "/" + b +// CaddyModule returns the Caddy module information. +func (Handler) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy", + New: func() caddy.Module { return new(Handler) }, } - return a + b } -// NewSingleHostReverseProxy returns a new ReverseProxy that routes -// URLs to the scheme, host, and base path provided in target. If the -// target's path is "/base" and the incoming request was for "/dir", -// the target request will be for /base/dir. -// NewSingleHostReverseProxy does not rewrite the Host header. -// To rewrite Host headers, use ReverseProxy directly with a custom -// Director policy. -func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { - targetQuery := target.RawQuery - director := func(req *http.Request) { - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) - if targetQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = targetQuery + req.URL.RawQuery - } else { - req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery +// Provision ensures that h is set up properly before use. +func (h *Handler) Provision(ctx caddy.Context) error { + // start by loading modules + if h.TransportRaw != nil { + val, err := ctx.LoadModuleInline("protocol", "http.handlers.reverse_proxy.transport", h.TransportRaw) + if err != nil { + return fmt.Errorf("loading transport module: %s", err) } - if _, ok := req.Header["User-Agent"]; !ok { - // explicitly disable User-Agent so it's not set to default value - req.Header.Set("User-Agent", "") + h.Transport = val.(http.RoundTripper) + h.TransportRaw = nil // allow GC to deallocate - TODO: Does this help? + } + if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil { + val, err := ctx.LoadModuleInline("policy", + "http.handlers.reverse_proxy.selection_policies", + h.LoadBalancing.SelectionPolicyRaw) + if err != nil { + return fmt.Errorf("loading load balancing selection module: %s", err) } + h.LoadBalancing.SelectionPolicy = val.(Selector) + h.LoadBalancing.SelectionPolicyRaw = nil // allow GC to deallocate - TODO: Does this help? + } + if h.CBRaw != nil { + val, err := ctx.LoadModuleInline("type", "http.handlers.reverse_proxy.circuit_breakers", h.CBRaw) + if err != nil { + return fmt.Errorf("loading circuit breaker module: %s", err) + } + h.CB = val.(CircuitBreaker) + h.CBRaw = nil // allow GC to deallocate - TODO: Does this help? } - return &ReverseProxy{Director: director} -} -func copyHeader(dst, src http.Header) { - for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) + if h.Transport == nil { + t := &HTTPTransport{ + KeepAlive: &KeepAlive{ + ProbeInterval: caddy.Duration(30 * time.Second), + IdleConnTimeout: caddy.Duration(2 * time.Minute), + }, + DialTimeout: caddy.Duration(10 * time.Second), } + err := t.Provision(ctx) + if err != nil { + return fmt.Errorf("provisioning default transport: %v", err) + } + h.Transport = t } -} -func cloneHeader(h http.Header) http.Header { - h2 := make(http.Header, len(h)) - for k, vv := range h { - vv2 := make([]string, len(vv)) - copy(vv2, vv) - h2[k] = vv2 + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + if h.LoadBalancing.SelectionPolicy == nil { + h.LoadBalancing.SelectionPolicy = RandomSelection{} + } + if h.LoadBalancing.TryDuration > 0 && h.LoadBalancing.TryInterval == 0 { + // a non-zero try_duration with a zero try_interval + // will always spin the CPU for try_duration if the + // upstream is local or low-latency; avoid that by + // defaulting to a sane wait period between attempts + h.LoadBalancing.TryInterval = caddy.Duration(250 * time.Millisecond) } - return h2 -} -// Hop-by-hop headers. These are removed when sent to the backend. -// As of RFC 7230, hop-by-hop headers are required to appear in the -// Connection header field. These are the headers defined by the -// obsoleted RFC 2616 (section 13.5.1) and are used for backward -// compatibility. -var hopHeaders = []string{ - "Connection", - "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "Te", // canonicalized version of "TE" - "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 - "Transfer-Encoding", - "Upgrade", -} + // if active health checks are enabled, configure them and start a worker + if h.HealthChecks != nil && + h.HealthChecks.Active != nil && + (h.HealthChecks.Active.Path != "" || h.HealthChecks.Active.Port != 0) { + timeout := time.Duration(h.HealthChecks.Active.Timeout) + if timeout == 0 { + timeout = 10 * time.Second + } -func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { - p.logf("http: proxy error: %v", err) - rw.WriteHeader(http.StatusBadGateway) -} + h.HealthChecks.Active.stopChan = make(chan struct{}) + h.HealthChecks.Active.httpClient = &http.Client{ + Timeout: timeout, + Transport: h.Transport, + } + + if h.HealthChecks.Active.Interval == 0 { + h.HealthChecks.Active.Interval = caddy.Duration(30 * time.Second) + } + + if h.HealthChecks.Active.ExpectBody != "" { + var err error + h.HealthChecks.Active.bodyRegexp, err = regexp.Compile(h.HealthChecks.Active.ExpectBody) + if err != nil { + return fmt.Errorf("expect_body: compiling regular expression: %v", err) + } + } + + go h.activeHealthChecker() + } -func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { - if p.ErrorHandler != nil { - return p.ErrorHandler + var allUpstreams []*Upstream + for _, upstream := range h.Upstreams { + // upstreams are allowed to map to only a single host, + // but an upstream's address may semantically represent + // multiple addresses, so make sure to handle each + // one in turn based on this one upstream config + network, addresses, err := caddy.ParseNetworkAddress(upstream.Dial) + if err != nil { + return fmt.Errorf("parsing dial address: %v", err) + } + + for _, addr := range addresses { + // make a new upstream based on the original + // that has a singular dial address + upstreamCopy := *upstream + upstreamCopy.dialInfo = DialInfo{network, addr} + upstreamCopy.Dial = upstreamCopy.dialInfo.String() + upstreamCopy.cb = h.CB + + // if host already exists from a current config, + // use that instead; otherwise, add it + // TODO: make hosts modular, so that their state can be distributed in enterprise for example + // TODO: If distributed, the pool should be stored in storage... + var host Host = new(upstreamHost) + activeHost, loaded := hosts.LoadOrStore(upstreamCopy.Dial, host) + if loaded { + host = activeHost.(Host) + } + upstreamCopy.Host = host + + // if the passive health checker has a non-zero "unhealthy + // request count" but the upstream has no MaxRequests set + // (they are the same thing, but one is a default value for + // for upstreams with a zero MaxRequests), copy the default + // value into this upstream, since the value in the upstream + // (MaxRequests) is what is used during availability checks + if h.HealthChecks != nil && + h.HealthChecks.Passive != nil && + h.HealthChecks.Passive.UnhealthyRequestCount > 0 && + upstreamCopy.MaxRequests == 0 { + upstreamCopy.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount + } + + // upstreams need independent access to the passive + // health check policy because they run outside of the + // scope of a request handler + if h.HealthChecks != nil { + upstreamCopy.healthCheckPolicy = h.HealthChecks.Passive + } + + allUpstreams = append(allUpstreams, &upstreamCopy) + } } - return p.defaultErrorHandler + + // replace the unmarshaled upstreams (possible 1:many + // address mapping) with our list, which is mapped 1:1, + // thus may have expanded the original list + h.Upstreams = allUpstreams + + return nil } -// modifyResponse conditionally runs the optional ModifyResponse hook -// and reports whether the request should proceed. -func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { - if p.ModifyResponse == nil { - return true +// Cleanup cleans up the resources made by h during provisioning. +func (h *Handler) Cleanup() error { + // stop the active health checker + if h.HealthChecks != nil && + h.HealthChecks.Active != nil && + h.HealthChecks.Active.stopChan != nil { + close(h.HealthChecks.Active.stopChan) } - if err := p.ModifyResponse(res); err != nil { - res.Body.Close() - p.getErrorHandler()(rw, req, err) - return false + + // remove hosts from our config from the pool + for _, upstream := range h.Upstreams { + hosts.Delete(upstream.dialInfo.String()) } - return true + + return nil } -func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*http.Response, error) { - transport := p.Transport - if transport == nil { - transport = http.DefaultTransport - } - - ctx := req.Context() - if cn, ok := rw.(http.CloseNotifier); ok { - var cancel context.CancelFunc - ctx, cancel = context.WithCancel(ctx) - defer cancel() - notifyChan := cn.CloseNotify() - go func() { - select { - case <-notifyChan: - cancel() - case <-ctx.Done(): +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { + // prepare the request for proxying; this is needed only once + err := h.prepareRequest(r) + if err != nil { + return caddyhttp.Error(http.StatusInternalServerError, + fmt.Errorf("preparing request for upstream round-trip: %v", err)) + } + + start := time.Now() + + var proxyErr error + for { + // choose an available upstream + upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, r) + if upstream == nil { + if proxyErr == nil { + proxyErr = fmt.Errorf("no available upstreams") + } + if !h.tryAgain(start, proxyErr) { + break } - }() + continue + } + + // attach to the request information about how to dial the upstream; + // this is necessary because the information cannot be sufficiently + // or satisfactorily represented in a URL + ctx := context.WithValue(r.Context(), DialInfoCtxKey, upstream.dialInfo) + r = r.WithContext(ctx) + + // proxy the request to that upstream + proxyErr = h.reverseProxy(w, r, upstream) + if proxyErr == nil || proxyErr == context.Canceled { + // context.Canceled happens when the downstream client + // cancels the request; we don't have to worry about that + return nil + } + + // remember this failure (if enabled) + h.countFailure(upstream) + + // if we've tried long enough, break + if !h.tryAgain(start, proxyErr) { + break + } + } + + return caddyhttp.Error(http.StatusBadGateway, proxyErr) +} + +// prepareRequest modifies req so that it is ready to be proxied, +// except for directing to a specific upstream. This method mutates +// headers and other necessary properties of the request and should +// be done just once (before proxying) regardless of proxy retries. +// This assumes that no mutations of the request are performed +// by h during or after proxying. +func (h Handler) prepareRequest(req *http.Request) error { + // as a special (but very common) case, if the transport + // is HTTP, then ensure the request has the proper scheme + // because incoming requests by default are lacking it + if req.URL.Scheme == "" { + req.URL.Scheme = "http" + if ht, ok := h.Transport.(*HTTPTransport); ok && ht.TLS != nil { + req.URL.Scheme = "https" + } } - outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay if req.ContentLength == 0 { - outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries } - outreq.Header = cloneHeader(req.Header) + req.Close = false - p.Director(outreq) - outreq.Close = false + // if User-Agent is not set by client, then explicitly + // disable it so it's not set to default value by std lib + if _, ok := req.Header["User-Agent"]; !ok { + req.Header.Set("User-Agent", "") + } - reqUpType := upgradeType(outreq.Header) - removeConnectionHeaders(outreq.Header) + reqUpType := upgradeType(req.Header) + removeConnectionHeaders(req.Header) // Remove hop-by-hop headers to the backend. Especially // important is "Connection" because we want a persistent // connection, regardless of what the client sent to us. for _, h := range hopHeaders { - hv := outreq.Header.Get(h) + hv := req.Header.Get(h) if hv == "" { continue } if h == "Te" && hv == "trailers" { - // Issue 21096: tell backend applications that + // Issue golang/go#21096: tell backend applications that // care about trailer support that we support // trailers. (We do, but we don't go out of // our way to advertise that unless the @@ -244,40 +321,72 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht // worth mentioning) continue } - outreq.Header.Del(h) + req.Header.Del(h) } // After stripping all the hop-by-hop connection headers above, add back any // necessary for protocol upgrades, such as for websockets. if reqUpType != "" { - outreq.Header.Set("Connection", "Upgrade") - outreq.Header.Set("Upgrade", reqUpType) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", reqUpType) } if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { // If we aren't the first proxy retain prior // X-Forwarded-For information as a comma+space // separated list and fold multiple headers into one. - if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + if prior, ok := req.Header["X-Forwarded-For"]; ok { clientIP = strings.Join(prior, ", ") + ", " + clientIP } - outreq.Header.Set("X-Forwarded-For", clientIP) + req.Header.Set("X-Forwarded-For", clientIP) } - res, err := transport.RoundTrip(outreq) + return nil +} + +// reverseProxy performs a round-trip to the given backend and processes the response with the client. +// (This method is mostly the beginning of what was borrowed from the net/http/httputil package in the +// Go standard library which was used as the foundation.) +func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, upstream *Upstream) error { + upstream.Host.CountRequest(1) + defer upstream.Host.CountRequest(-1) + + // point the request to this upstream + h.directRequest(req, upstream) + + // do the round-trip + start := time.Now() + res, err := h.Transport.RoundTrip(req) + latency := time.Since(start) if err != nil { - p.getErrorHandler()(rw, outreq, err) - return nil, err + return err } - // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) - if res.StatusCode == http.StatusSwitchingProtocols { - if !p.modifyResponse(rw, res, outreq) { - return res, nil + // update circuit breaker on current conditions + if upstream.cb != nil { + upstream.cb.RecordMetric(res.StatusCode, latency) + } + + // perform passive health checks (if enabled) + if h.HealthChecks != nil && h.HealthChecks.Passive != nil { + // strike if the status code matches one that is "bad" + for _, badStatus := range h.HealthChecks.Passive.UnhealthyStatus { + if caddyhttp.StatusCodeMatches(res.StatusCode, badStatus) { + h.countFailure(upstream) + } } - p.handleUpgradeResponse(rw, outreq, res) - return res, nil + // strike if the roundtrip took too long + if h.HealthChecks.Passive.UnhealthyLatency > 0 && + latency >= time.Duration(h.HealthChecks.Passive.UnhealthyLatency) { + h.countFailure(upstream) + } + } + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + h.handleUpgradeResponse(rw, req, res) + return nil } removeConnectionHeaders(res.Header) @@ -286,10 +395,6 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht res.Header.Del(h) } - if !p.modifyResponse(rw, res, outreq) { - return res, nil - } - copyHeader(rw.Header(), res.Header) // The "Trailer" header isn't included in the Transport's response, @@ -305,15 +410,16 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht rw.WriteHeader(res.StatusCode) - err = p.copyResponse(rw, res.Body, p.flushInterval(req, res)) + err = h.copyResponse(rw, res.Body, h.flushInterval(req, res)) if err != nil { defer res.Body.Close() // Since we're streaming the response, if we run into an error all we can do - // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler + // is abort the request. Issue golang/go#23643: ReverseProxy should use ErrAbortHandler // on read error while copying body. + // TODO: Look into whether we want to panic at all in our case... if !shouldPanicOnCopyError(req) { - p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) - return nil, err + // p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) + return err } panic(http.ErrAbortHandler) @@ -331,7 +437,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht if len(res.Trailer) == announcedTrailers { copyHeader(rw.Header(), res.Trailer) - return res, nil + return nil } for k, vv := range res.Trailer { @@ -341,21 +447,48 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht } } - return res, nil + return nil +} + +// tryAgain takes the time that the handler was initially invoked +// as well as any error currently obtained and returns true if +// another attempt should be made at proxying the request. If +// true is returned, it has already blocked long enough before +// the next retry (i.e. no more sleeping is needed). If false is +// returned, the handler should stop trying to proxy the request. +func (h Handler) tryAgain(start time.Time, proxyErr error) bool { + // if downstream has canceled the request, break + if proxyErr == context.Canceled { + return false + } + // if we've tried long enough, break + if time.Since(start) >= time.Duration(h.LoadBalancing.TryDuration) { + return false + } + // otherwise, wait and try the next available host + time.Sleep(time.Duration(h.LoadBalancing.TryInterval)) + return true } -var inOurTests bool // whether we're in our own tests +// directRequest modifies only req.URL so that it points to the +// given upstream host. It must modify ONLY the request URL. +func (h Handler) directRequest(req *http.Request, upstream *Upstream) { + if req.URL.Host == "" { + req.URL.Host = upstream.dialInfo.Address + } +} // shouldPanicOnCopyError reports whether the reverse proxy should // panic with http.ErrAbortHandler. This is the right thing to do by // default, but Go 1.10 and earlier did not, so existing unit tests // weren't expecting panics. Only panic in our own tests, or when // running under the HTTP server. +// TODO: I don't know if we want this at all... func shouldPanicOnCopyError(req *http.Request) bool { - if inOurTests { - // Our tests know to handle this panic. - return true - } + // if inOurTests { + // // Our tests know to handle this panic. + // return true + // } if req.Context().Value(http.ServerContextKey) != nil { // We seem to be running under an HTTP server, so // it'll recover the panic. @@ -366,146 +499,22 @@ func shouldPanicOnCopyError(req *http.Request) bool { return false } -// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. -// See RFC 7230, section 6.1 -func removeConnectionHeaders(h http.Header) { - if c := h.Get("Connection"); c != "" { - for _, f := range strings.Split(c, ",") { - if f = strings.TrimSpace(f); f != "" { - h.Del(f) - } - } - } -} - -// flushInterval returns the p.FlushInterval value, conditionally -// overriding its value for a specific request/response. -func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration { - resCT := res.Header.Get("Content-Type") - - // For Server-Sent Events responses, flush immediately. - // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream - if resCT == "text/event-stream" { - return -1 // negative means immediately - } - - // TODO: more specific cases? e.g. res.ContentLength == -1? - return p.FlushInterval -} - -func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { - if flushInterval != 0 { - if wf, ok := dst.(writeFlusher); ok { - mlw := &maxLatencyWriter{ - dst: wf, - latency: flushInterval, - } - defer mlw.stop() - dst = mlw - } - } - - var buf []byte - if p.BufferPool != nil { - buf = p.BufferPool.Get() - defer p.BufferPool.Put(buf) - } - _, err := p.copyBuffer(dst, src, buf) - return err -} - -// copyBuffer returns any write errors or non-EOF read errors, and the amount -// of bytes written. -func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { - if len(buf) == 0 { - buf = make([]byte, 32*1024) - } - var written int64 - for { - nr, rerr := src.Read(buf) - if rerr != nil && rerr != io.EOF && rerr != context.Canceled { - p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) - } - if nr > 0 { - nw, werr := dst.Write(buf[:nr]) - if nw > 0 { - written += int64(nw) - } - if werr != nil { - return written, werr - } - if nr != nw { - return written, io.ErrShortWrite - } - } - if rerr != nil { - if rerr == io.EOF { - rerr = nil - } - return written, rerr +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) } } } -func (p *ReverseProxy) logf(format string, args ...interface{}) { - if p.ErrorLog != nil { - p.ErrorLog.Printf(format, args...) - } else { - log.Printf(format, args...) - } -} - -type writeFlusher interface { - io.Writer - http.Flusher -} - -type maxLatencyWriter struct { - dst writeFlusher - latency time.Duration // non-zero; negative means to flush immediately - - mu sync.Mutex // protects t, flushPending, and dst.Flush - t *time.Timer - flushPending bool -} - -func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { - m.mu.Lock() - defer m.mu.Unlock() - n, err = m.dst.Write(p) - if m.latency < 0 { - m.dst.Flush() - return - } - if m.flushPending { - return - } - if m.t == nil { - m.t = time.AfterFunc(m.latency, m.delayedFlush) - } else { - m.t.Reset(m.latency) - } - m.flushPending = true - return -} - -func (m *maxLatencyWriter) delayedFlush() { - m.mu.Lock() - defer m.mu.Unlock() - if !m.flushPending { // if stop was called but AfterFunc already started this goroutine - return - } - m.dst.Flush() - m.flushPending = false -} - -func (m *maxLatencyWriter) stop() { - m.mu.Lock() - defer m.mu.Unlock() - m.flushPending = false - if m.t != nil { - m.t.Stop() +func cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 } + return h2 } func upgradeType(h http.Header) string { @@ -515,62 +524,71 @@ func upgradeType(h http.Header) string { return strings.ToLower(h.Get("Upgrade")) } -func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { - reqUpType := upgradeType(req.Header) - resUpType := upgradeType(res.Header) - if reqUpType != resUpType { - p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) - return +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b } + return a + b +} - copyHeader(res.Header, rw.Header()) - - hj, ok := rw.(http.Hijacker) - if !ok { - p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) - return - } - backConn, ok := res.Body.(io.ReadWriteCloser) - if !ok { - p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) - return +// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. +// See RFC 7230, section 6.1 +func removeConnectionHeaders(h http.Header) { + if c := h.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + h.Del(f) + } + } } - defer backConn.Close() - conn, brw, err := hj.Hijack() - if err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) - return - } - defer conn.Close() - res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above - if err := res.Write(brw); err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) - return - } - if err := brw.Flush(); err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) - return - } - errc := make(chan error, 1) - spc := switchProtocolCopier{user: conn, backend: backConn} - go spc.copyToBackend(errc) - go spc.copyFromBackend(errc) - <-errc - return } -// switchProtocolCopier exists so goroutines proxying data back and -// forth have nice names in stacks. -type switchProtocolCopier struct { - user, backend io.ReadWriter +// LoadBalancing has parameters related to load balancing. +type LoadBalancing struct { + SelectionPolicyRaw json.RawMessage `json:"selection_policy,omitempty"` + TryDuration caddy.Duration `json:"try_duration,omitempty"` + TryInterval caddy.Duration `json:"try_interval,omitempty"` + + SelectionPolicy Selector `json:"-"` } -func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { - _, err := io.Copy(c.user, c.backend) - errc <- err +// Selector selects an available upstream from the pool. +type Selector interface { + Select(UpstreamPool, *http.Request) *Upstream } -func (c switchProtocolCopier) copyToBackend(errc chan<- error) { - _, err := io.Copy(c.backend, c.user) - errc <- err +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", } + +// TODO: see if we can use this +// var bufPool = sync.Pool{ +// New: func() interface{} { +// return new(bytes.Buffer) +// }, +// } + +// Interface guards +var ( + _ caddy.Provisioner = (*Handler)(nil) + _ caddy.CleanerUpper = (*Handler)(nil) + _ caddyhttp.MiddlewareHandler = (*Handler)(nil) +) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go new file mode 100644 index 0000000..5bb2d62 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -0,0 +1,353 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "fmt" + "hash/fnv" + weakrand "math/rand" + "net" + "net/http" + "sync/atomic" + "time" + + "github.com/caddyserver/caddy/v2" +) + +func init() { + caddy.RegisterModule(RandomSelection{}) + caddy.RegisterModule(RandomChoiceSelection{}) + caddy.RegisterModule(LeastConnSelection{}) + caddy.RegisterModule(RoundRobinSelection{}) + caddy.RegisterModule(FirstSelection{}) + caddy.RegisterModule(IPHashSelection{}) + caddy.RegisterModule(URIHashSelection{}) + caddy.RegisterModule(HeaderHashSelection{}) + + weakrand.Seed(time.Now().UTC().UnixNano()) +} + +// RandomSelection is a policy that selects +// an available host at random. +type RandomSelection struct{} + +// CaddyModule returns the Caddy module information. +func (RandomSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.random", + New: func() caddy.Module { return new(RandomSelection) }, + } +} + +// Select returns an available host, if any. +func (r RandomSelection) Select(pool UpstreamPool, request *http.Request) *Upstream { + // use reservoir sampling because the number of available + // hosts isn't known: https://en.wikipedia.org/wiki/Reservoir_sampling + var randomHost *Upstream + var count int + for _, upstream := range pool { + if !upstream.Available() { + continue + } + // (n % 1 == 0) holds for all n, therefore a + // upstream will always be chosen if there is at + // least one available + count++ + if (weakrand.Int() % count) == 0 { + randomHost = upstream + } + } + return randomHost +} + +// RandomChoiceSelection is a policy that selects +// two or more available hosts at random, then +// chooses the one with the least load. +type RandomChoiceSelection struct { + Choose int `json:"choose,omitempty"` +} + +// CaddyModule returns the Caddy module information. +func (RandomChoiceSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.random_choose", + New: func() caddy.Module { return new(RandomChoiceSelection) }, + } +} + +// Provision sets up r. +func (r *RandomChoiceSelection) Provision(ctx caddy.Context) error { + if r.Choose == 0 { + r.Choose = 2 + } + return nil +} + +// Validate ensures that r's configuration is valid. +func (r RandomChoiceSelection) Validate() error { + if r.Choose < 2 { + return fmt.Errorf("choose must be at least 2") + } + return nil +} + +// Select returns an available host, if any. +func (r RandomChoiceSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { + k := r.Choose + if k > len(pool) { + k = len(pool) + } + choices := make([]*Upstream, k) + for i, upstream := range pool { + if !upstream.Available() { + continue + } + j := weakrand.Intn(i) + if j < k { + choices[j] = upstream + } + } + return leastRequests(choices) +} + +// LeastConnSelection is a policy that selects the +// host with the least active requests. If multiple +// hosts have the same fewest number, one is chosen +// randomly. The term "conn" or "connection" is used +// in this policy name due to its similar meaning in +// other software, but our load balancer actually +// counts active requests rather than connections, +// since these days requests are multiplexed onto +// shared connections. +type LeastConnSelection struct{} + +// CaddyModule returns the Caddy module information. +func (LeastConnSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.least_conn", + New: func() caddy.Module { return new(LeastConnSelection) }, + } +} + +// Select selects the up host with the least number of connections in the +// pool. If more than one host has the same least number of connections, +// one of the hosts is chosen at random. +func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { + var bestHost *Upstream + var count int + leastReqs := -1 + + for _, host := range pool { + if !host.Available() { + continue + } + numReqs := host.NumRequests() + if leastReqs == -1 || numReqs < leastReqs { + leastReqs = numReqs + count = 0 + } + + // among hosts with same least connections, perform a reservoir + // sample: https://en.wikipedia.org/wiki/Reservoir_sampling + if numReqs == leastReqs { + count++ + if (weakrand.Int() % count) == 0 { + bestHost = host + } + } + } + + return bestHost +} + +// RoundRobinSelection is a policy that selects +// a host based on round-robin ordering. +type RoundRobinSelection struct { + robin uint32 +} + +// CaddyModule returns the Caddy module information. +func (RoundRobinSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.round_robin", + New: func() caddy.Module { return new(RoundRobinSelection) }, + } +} + +// Select returns an available host, if any. +func (r *RoundRobinSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { + n := uint32(len(pool)) + if n == 0 { + return nil + } + for i := uint32(0); i < n; i++ { + atomic.AddUint32(&r.robin, 1) + host := pool[r.robin%n] + if host.Available() { + return host + } + } + return nil +} + +// FirstSelection is a policy that selects +// the first available host. +type FirstSelection struct{} + +// CaddyModule returns the Caddy module information. +func (FirstSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.first", + New: func() caddy.Module { return new(FirstSelection) }, + } +} + +// Select returns an available host, if any. +func (FirstSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { + for _, host := range pool { + if host.Available() { + return host + } + } + return nil +} + +// IPHashSelection is a policy that selects a host +// based on hashing the remote IP of the request. +type IPHashSelection struct{} + +// CaddyModule returns the Caddy module information. +func (IPHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.ip_hash", + New: func() caddy.Module { return new(IPHashSelection) }, + } +} + +// Select returns an available host, if any. +func (IPHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream { + clientIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + clientIP = req.RemoteAddr + } + return hostByHashing(pool, clientIP) +} + +// URIHashSelection is a policy that selects a +// host by hashing the request URI. +type URIHashSelection struct{} + +// CaddyModule returns the Caddy module information. +func (URIHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.uri_hash", + New: func() caddy.Module { return new(URIHashSelection) }, + } +} + +// Select returns an available host, if any. +func (URIHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream { + return hostByHashing(pool, req.RequestURI) +} + +// HeaderHashSelection is a policy that selects +// a host based on a given request header. +type HeaderHashSelection struct { + Field string `json:"field,omitempty"` +} + +// CaddyModule returns the Caddy module information. +func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.header", + New: func() caddy.Module { return new(HeaderHashSelection) }, + } +} + +// Select returns an available host, if any. +func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream { + if s.Field == "" { + return nil + } + val := req.Header.Get(s.Field) + if val == "" { + return RandomSelection{}.Select(pool, req) + } + return hostByHashing(pool, val) +} + +// leastRequests returns the host with the +// least number of active requests to it. +// If more than one host has the same +// least number of active requests, then +// one of those is chosen at random. +func leastRequests(upstreams []*Upstream) *Upstream { + if len(upstreams) == 0 { + return nil + } + var best []*Upstream + var bestReqs int + for _, upstream := range upstreams { + reqs := upstream.NumRequests() + if reqs == 0 { + return upstream + } + if reqs <= bestReqs { + bestReqs = reqs + best = append(best, upstream) + } + } + return best[weakrand.Intn(len(best))] +} + +// hostByHashing returns an available host +// from pool based on a hashable string s. +func hostByHashing(pool []*Upstream, s string) *Upstream { + poolLen := uint32(len(pool)) + if poolLen == 0 { + return nil + } + index := hash(s) % poolLen + for i := uint32(0); i < poolLen; i++ { + index += i + upstream := pool[index%poolLen] + if upstream.Available() { + return upstream + } + } + return nil +} + +// hash calculates a fast hash based on s. +func hash(s string) uint32 { + h := fnv.New32a() + h.Write([]byte(s)) + return h.Sum32() +} + +// Interface guards +var ( + _ Selector = (*RandomSelection)(nil) + _ Selector = (*RandomChoiceSelection)(nil) + _ Selector = (*LeastConnSelection)(nil) + _ Selector = (*RoundRobinSelection)(nil) + _ Selector = (*FirstSelection)(nil) + _ Selector = (*IPHashSelection)(nil) + _ Selector = (*URIHashSelection)(nil) + _ Selector = (*HeaderHashSelection)(nil) + + _ caddy.Validator = (*RandomChoiceSelection)(nil) + _ caddy.Provisioner = (*RandomChoiceSelection)(nil) +) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go new file mode 100644 index 0000000..e9939d6 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go @@ -0,0 +1,273 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func testPool() UpstreamPool { + return UpstreamPool{ + {Host: new(upstreamHost)}, + {Host: new(upstreamHost)}, + {Host: new(upstreamHost)}, + } +} + +func TestRoundRobinPolicy(t *testing.T) { + pool := testPool() + rrPolicy := new(RoundRobinSelection) + req, _ := http.NewRequest("GET", "/", nil) + + h := rrPolicy.Select(pool, req) + // First selected host is 1, because counter starts at 0 + // and increments before host is selected + if h != pool[1] { + t.Error("Expected first round robin host to be second host in the pool.") + } + h = rrPolicy.Select(pool, req) + if h != pool[2] { + t.Error("Expected second round robin host to be third host in the pool.") + } + h = rrPolicy.Select(pool, req) + if h != pool[0] { + t.Error("Expected third round robin host to be first host in the pool.") + } + // mark host as down + pool[1].SetHealthy(false) + h = rrPolicy.Select(pool, req) + if h != pool[2] { + t.Error("Expected to skip down host.") + } + // mark host as up + pool[1].SetHealthy(true) + + h = rrPolicy.Select(pool, req) + if h == pool[2] { + t.Error("Expected to balance evenly among healthy hosts") + } + // mark host as full + pool[1].CountRequest(1) + pool[1].MaxRequests = 1 + h = rrPolicy.Select(pool, req) + if h != pool[2] { + t.Error("Expected to skip full host.") + } +} + +func TestLeastConnPolicy(t *testing.T) { + pool := testPool() + lcPolicy := new(LeastConnSelection) + req, _ := http.NewRequest("GET", "/", nil) + + pool[0].CountRequest(10) + pool[1].CountRequest(10) + h := lcPolicy.Select(pool, req) + if h != pool[2] { + t.Error("Expected least connection host to be third host.") + } + pool[2].CountRequest(100) + h = lcPolicy.Select(pool, req) + if h != pool[0] && h != pool[1] { + t.Error("Expected least connection host to be first or second host.") + } +} + +func TestIPHashPolicy(t *testing.T) { + pool := testPool() + ipHash := new(IPHashSelection) + req, _ := http.NewRequest("GET", "/", nil) + + // We should be able to predict where every request is routed. + req.RemoteAddr = "172.0.0.1:80" + h := ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + req.RemoteAddr = "172.0.0.2:80" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + req.RemoteAddr = "172.0.0.3:80" + h = ipHash.Select(pool, req) + if h != pool[2] { + t.Error("Expected ip hash policy host to be the third host.") + } + req.RemoteAddr = "172.0.0.4:80" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + + // we should get the same results without a port + req.RemoteAddr = "172.0.0.1" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + req.RemoteAddr = "172.0.0.2" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + req.RemoteAddr = "172.0.0.3" + h = ipHash.Select(pool, req) + if h != pool[2] { + t.Error("Expected ip hash policy host to be the third host.") + } + req.RemoteAddr = "172.0.0.4" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + + // we should get a healthy host if the original host is unhealthy and a + // healthy host is available + req.RemoteAddr = "172.0.0.1" + pool[1].SetHealthy(false) + h = ipHash.Select(pool, req) + if h != pool[2] { + t.Error("Expected ip hash policy host to be the third host.") + } + + req.RemoteAddr = "172.0.0.2" + h = ipHash.Select(pool, req) + if h != pool[2] { + t.Error("Expected ip hash policy host to be the third host.") + } + pool[1].SetHealthy(true) + + req.RemoteAddr = "172.0.0.3" + pool[2].SetHealthy(false) + h = ipHash.Select(pool, req) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + req.RemoteAddr = "172.0.0.4" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + + // We should be able to resize the host pool and still be able to predict + // where a req will be routed with the same IP's used above + pool = UpstreamPool{ + {Host: new(upstreamHost)}, + {Host: new(upstreamHost)}, + } + req.RemoteAddr = "172.0.0.1:80" + h = ipHash.Select(pool, req) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + req.RemoteAddr = "172.0.0.2:80" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + req.RemoteAddr = "172.0.0.3:80" + h = ipHash.Select(pool, req) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + req.RemoteAddr = "172.0.0.4:80" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + + // We should get nil when there are no healthy hosts + pool[0].SetHealthy(false) + pool[1].SetHealthy(false) + h = ipHash.Select(pool, req) + if h != nil { + t.Error("Expected ip hash policy host to be nil.") + } +} + +func TestFirstPolicy(t *testing.T) { + pool := testPool() + firstPolicy := new(FirstSelection) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + h := firstPolicy.Select(pool, req) + if h != pool[0] { + t.Error("Expected first policy host to be the first host.") + } + + pool[0].SetHealthy(false) + h = firstPolicy.Select(pool, req) + if h != pool[1] { + t.Error("Expected first policy host to be the second host.") + } +} + +func TestURIHashPolicy(t *testing.T) { + pool := testPool() + uriPolicy := new(URIHashSelection) + + request := httptest.NewRequest(http.MethodGet, "/test", nil) + h := uriPolicy.Select(pool, request) + if h != pool[0] { + t.Error("Expected uri policy host to be the first host.") + } + + pool[0].SetHealthy(false) + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the first host.") + } + + request = httptest.NewRequest(http.MethodGet, "/test_2", nil) + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the second host.") + } + + // We should be able to resize the host pool and still be able to predict + // where a request will be routed with the same URI's used above + pool = UpstreamPool{ + {Host: new(upstreamHost)}, + {Host: new(upstreamHost)}, + } + + request = httptest.NewRequest(http.MethodGet, "/test", nil) + h = uriPolicy.Select(pool, request) + if h != pool[0] { + t.Error("Expected uri policy host to be the first host.") + } + + pool[0].SetHealthy(false) + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the first host.") + } + + request = httptest.NewRequest(http.MethodGet, "/test_2", nil) + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the second host.") + } + + pool[0].SetHealthy(false) + pool[1].SetHealthy(false) + h = uriPolicy.Select(pool, request) + if h != nil { + t.Error("Expected uri policy policy host to be nil.") + } +} diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go new file mode 100644 index 0000000..a3b711b --- /dev/null +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -0,0 +1,223 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Most of the code in this file was initially borrowed from the Go +// standard library, which has this copyright notice: +// Copyright 2011 The Go Authors. + +package reverseproxy + +import ( + "context" + "io" + "net/http" + "sync" + "time" +) + +func (h Handler) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if reqUpType != resUpType { + // TODO: figure out our own error handling + // p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + + copyHeader(res.Header, rw.Header()) + + hj, ok := rw.(http.Hijacker) + if !ok { + // p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + // p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + defer backConn.Close() + conn, brw, err := hj.Hijack() + if err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + return + } + defer conn.Close() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc + return +} + +// flushInterval returns the p.FlushInterval value, conditionally +// overriding its value for a specific request/response. +func (h Handler) flushInterval(req *http.Request, res *http.Response) time.Duration { + resCT := res.Header.Get("Content-Type") + + // For Server-Sent Events responses, flush immediately. + // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream + if resCT == "text/event-stream" { + return -1 // negative means immediately + } + + // TODO: more specific cases? e.g. res.ContentLength == -1? (this TODO is from the std lib) + return time.Duration(h.FlushInterval) +} + +func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { + if flushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: flushInterval, + } + defer mlw.stop() + dst = mlw + } + } + + // TODO: Figure out how we want to do this... using custom buffer pool type seems unnecessary + // or maybe it is, depending on how we want to handle errors, + // see: https://github.com/golang/go/issues/21814 + // buf := bufPool.Get().(*bytes.Buffer) + // buf.Reset() + // defer bufPool.Put(buf) + // _, err := io.CopyBuffer(dst, src, ) + var buf []byte + // if h.BufferPool != nil { + // buf = h.BufferPool.Get() + // defer h.BufferPool.Put(buf) + // } + _, err := h.copyBuffer(dst, src, buf) + return err +} + +// copyBuffer returns any write errors or non-EOF read errors, and the amount +// of bytes written. +func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF && rerr != context.Canceled { + // TODO: this could be useful to know (indeed, it revealed an error in our + // fastcgi PoC earlier; but it's this single error report here that necessitates + // a function separate from io.CopyBuffer, since io.CopyBuffer does not distinguish + // between read or write errors; in a reverse proxy situation, write errors are not + // something we need to report to the client, but read errors are a problem on our + // end for sure. so we need to decide what we want.) + // p.logf("copyBuffer: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + if rerr == io.EOF { + rerr = nil + } + return written, rerr + } + } +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration // non-zero; negative means to flush immediately + + mu sync.Mutex // protects t, flushPending, and dst.Flush + t *time.Timer + flushPending bool +} + +func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + n, err = m.dst.Write(p) + if m.latency < 0 { + m.dst.Flush() + return + } + if m.flushPending { + return + } + if m.t == nil { + m.t = time.AfterFunc(m.latency, m.delayedFlush) + } else { + m.t.Reset(m.latency) + } + m.flushPending = true + return +} + +func (m *maxLatencyWriter) delayedFlush() { + m.mu.Lock() + defer m.mu.Unlock() + if !m.flushPending { // if stop was called but AfterFunc already started this goroutine + return + } + m.dst.Flush() + m.flushPending = false +} + +func (m *maxLatencyWriter) stop() { + m.mu.Lock() + defer m.mu.Unlock() + m.flushPending = false + if m.t != nil { + m.t.Stop() + } +} + +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} diff --git a/modules/caddyhttp/reverseproxy/upstream.go b/modules/caddyhttp/reverseproxy/upstream.go deleted file mode 100755 index 1f0693e..0000000 --- a/modules/caddyhttp/reverseproxy/upstream.go +++ /dev/null @@ -1,450 +0,0 @@ -// Copyright 2015 Matthew Holt and The Caddy Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package reverseproxy implements a load-balanced reverse proxy. -package reverseproxy - -import ( - "context" - "encoding/json" - "fmt" - "math/rand" - "net" - "net/http" - "net/url" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/caddyserver/caddy/v2" - "github.com/caddyserver/caddy/v2/modules/caddyhttp" -) - -// CircuitBreaker defines the functionality of a circuit breaker module. -type CircuitBreaker interface { - Ok() bool - RecordMetric(statusCode int, latency time.Duration) -} - -type noopCircuitBreaker struct{} - -func (ncb noopCircuitBreaker) RecordMetric(statusCode int, latency time.Duration) {} -func (ncb noopCircuitBreaker) Ok() bool { - return true -} - -const ( - // TypeBalanceRoundRobin represents the value to use for configuring a load balanced reverse proxy to use round robin load balancing. - TypeBalanceRoundRobin = iota - - // TypeBalanceRandom represents the value to use for configuring a load balanced reverse proxy to use random load balancing. - TypeBalanceRandom - - // TODO: add random with two choices - - // msgNoHealthyUpstreams is returned if there are no upstreams that are healthy to proxy a request to - msgNoHealthyUpstreams = "No healthy upstreams." - - // by default perform health checks every 30 seconds - defaultHealthCheckDur = time.Second * 30 - - // used when an upstream is unhealthy, health checks can be configured to perform at a faster rate - defaultFastHealthCheckDur = time.Second * 1 -) - -var ( - // defaultTransport is the default transport to use for the reverse proxy. - defaultTransport = &http.Transport{ - Dial: (&net.Dialer{ - Timeout: 5 * time.Second, - }).Dial, - TLSHandshakeTimeout: 5 * time.Second, - } - - // defaultHTTPClient is the default http client to use for the healthchecker. - defaultHTTPClient = &http.Client{ - Timeout: time.Second * 10, - Transport: defaultTransport, - } - - // typeMap maps caddy load balance configuration to the internal representation of the loadbalance algorithm type. - typeMap = map[string]int{ - "round_robin": TypeBalanceRoundRobin, - "random": TypeBalanceRandom, - } -) - -// NewLoadBalancedReverseProxy returns a collection of Upstreams that are to be loadbalanced. -func NewLoadBalancedReverseProxy(lb *LoadBalanced, ctx caddy.Context) error { - // set defaults - if lb.NoHealthyUpstreamsMessage == "" { - lb.NoHealthyUpstreamsMessage = msgNoHealthyUpstreams - } - - if lb.TryInterval == "" { - lb.TryInterval = "20s" - } - - // set request retry interval - ti, err := time.ParseDuration(lb.TryInterval) - if err != nil { - return fmt.Errorf("NewLoadBalancedReverseProxy: %v", err.Error()) - } - lb.tryInterval = ti - - // set load balance algorithm - t, ok := typeMap[lb.LoadBalanceType] - if !ok { - t = TypeBalanceRandom - } - lb.loadBalanceType = t - - // setup each upstream - var us []*upstream - for _, uc := range lb.Upstreams { - // pass the upstream decr and incr methods to keep track of unhealthy nodes - nu, err := newUpstream(uc, lb.decrUnhealthy, lb.incrUnhealthy) - if err != nil { - return err - } - - // setup any configured circuit breakers - var cbModule = "http.handlers.reverse_proxy.circuit_breaker" - var cb CircuitBreaker - - if uc.CircuitBreaker != nil { - if _, err := caddy.GetModule(cbModule); err == nil { - val, err := ctx.LoadModule(cbModule, uc.CircuitBreaker) - if err == nil { - cbv, ok := val.(CircuitBreaker) - if ok { - cb = cbv - } else { - fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error()) - cb = noopCircuitBreaker{} - } - } else { - fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error()) - cb = noopCircuitBreaker{} - } - } else { - fmt.Println("circuit_breaker module not loaded, using noop") - cb = noopCircuitBreaker{} - } - } else { - cb = noopCircuitBreaker{} - } - nu.CB = cb - - // start a healthcheck worker which will periodically check to see if an upstream is healthy - // to proxy requests to. - nu.healthChecker = NewHealthCheckWorker(nu, defaultHealthCheckDur, defaultHTTPClient) - - // TODO :- if path is empty why does this empty the entire Target? - // nu.Target.Path = uc.HealthCheckPath - - nu.healthChecker.ScheduleChecks(nu.Target.String()) - lb.HealthCheckers = append(lb.HealthCheckers, nu.healthChecker) - - us = append(us, nu) - } - - lb.upstreams = us - - return nil -} - -// LoadBalanced represents a collection of upstream hosts that are loadbalanced. It -// contains multiple features like health checking and circuit breaking functionality -// for upstreams. -type LoadBalanced struct { - mu sync.Mutex - numUnhealthy int32 - selectedServer int // used during round robin load balancing - loadBalanceType int - tryInterval time.Duration - upstreams []*upstream - - // The following struct fields are set by caddy configuration. - // TryInterval is the max duration for which request retrys will be performed for a request. - TryInterval string `json:"try_interval,omitempty"` - - // Upstreams are the configs for upstream hosts - Upstreams []*UpstreamConfig `json:"upstreams,omitempty"` - - // LoadBalanceType is the string representation of what loadbalancing algorithm to use. i.e. "random" or "round_robin". - LoadBalanceType string `json:"load_balance_type,omitempty"` - - // NoHealthyUpstreamsMessage is returned as a response when there are no healthy upstreams to loadbalance to. - NoHealthyUpstreamsMessage string `json:"no_healthy_upstreams_message,omitempty"` - - // TODO :- store healthcheckers as package level state where each upstream gets a single healthchecker - // currently a healthchecker is created for each upstream defined, even if a healthchecker was previously created - // for that upstream - HealthCheckers []*HealthChecker `json:"health_checkers,omitempty"` -} - -// Cleanup stops all health checkers on a loadbalanced reverse proxy. -func (lb *LoadBalanced) Cleanup() error { - for _, hc := range lb.HealthCheckers { - hc.Stop() - } - - return nil -} - -// Provision sets up a new loadbalanced reverse proxy. -func (lb *LoadBalanced) Provision(ctx caddy.Context) error { - return NewLoadBalancedReverseProxy(lb, ctx) -} - -// ServeHTTP implements the caddyhttp.MiddlewareHandler interface to -// dispatch an HTTP request to the proper server. -func (lb *LoadBalanced) ServeHTTP(w http.ResponseWriter, r *http.Request, _ caddyhttp.Handler) error { - // ensure requests don't hang if an upstream does not respond or is not eventually healthy - var u *upstream - var done bool - - retryTimer := time.NewTicker(lb.tryInterval) - defer retryTimer.Stop() - - go func() { - select { - case <-retryTimer.C: - done = true - } - }() - - // keep trying to get an available upstream to process the request - for { - switch lb.loadBalanceType { - case TypeBalanceRandom: - u = lb.random() - case TypeBalanceRoundRobin: - u = lb.roundRobin() - } - - // if we can't get an upstream and our retry interval has ended return an error response - if u == nil && done { - w.WriteHeader(http.StatusBadGateway) - fmt.Fprint(w, lb.NoHealthyUpstreamsMessage) - - return fmt.Errorf(msgNoHealthyUpstreams) - } - - // attempt to get an available upstream - if u == nil { - continue - } - - start := time.Now() - - // if we get an error retry until we get a healthy upstream - res, err := u.ReverseProxy.ServeHTTP(w, r) - if err != nil { - if err == context.Canceled { - return nil - } - - continue - } - - // record circuit breaker metrics - go u.CB.RecordMetric(res.StatusCode, time.Now().Sub(start)) - - return nil - } -} - -// incrUnhealthy increments the amount of unhealthy nodes in a loadbalancer. -func (lb *LoadBalanced) incrUnhealthy() { - atomic.AddInt32(&lb.numUnhealthy, 1) -} - -// decrUnhealthy decrements the amount of unhealthy nodes in a loadbalancer. -func (lb *LoadBalanced) decrUnhealthy() { - atomic.AddInt32(&lb.numUnhealthy, -1) -} - -// roundRobin implements a round robin load balancing algorithm to select -// which server to forward requests to. -func (lb *LoadBalanced) roundRobin() *upstream { - if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) { - return nil - } - - selected := lb.upstreams[lb.selectedServer] - - lb.mu.Lock() - lb.selectedServer++ - if lb.selectedServer >= len(lb.upstreams) { - lb.selectedServer = 0 - } - lb.mu.Unlock() - - if selected.IsHealthy() && selected.CB.Ok() { - return selected - } - - return nil -} - -// random implements a random server selector for load balancing. -func (lb *LoadBalanced) random() *upstream { - if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) { - return nil - } - - n := rand.Int() % len(lb.upstreams) - selected := lb.upstreams[n] - - if selected.IsHealthy() && selected.CB.Ok() { - return selected - } - - return nil -} - -// UpstreamConfig represents the config of an upstream. -type UpstreamConfig struct { - // Host is the host name of the upstream server. - Host string `json:"host,omitempty"` - - // FastHealthCheckDuration is the duration for which a health check is performed when a node is considered unhealthy. - FastHealthCheckDuration string `json:"fast_health_check_duration,omitempty"` - - CircuitBreaker json.RawMessage `json:"circuit_breaker,omitempty"` - - // // CircuitBreakerConfig is the config passed to setup a circuit breaker. - // CircuitBreakerConfig *circuitbreaker.Config `json:"circuit_breaker,omitempty"` - circuitbreaker CircuitBreaker - - // HealthCheckDuration is the default duration for which a health check is performed. - HealthCheckDuration string `json:"health_check_duration,omitempty"` - - // HealthCheckPath is the path at the upstream host to use for healthchecks. - HealthCheckPath string `json:"health_check_path,omitempty"` -} - -// upstream represents an upstream host. -type upstream struct { - Healthy int32 // 0 = false, 1 = true - Target *url.URL - ReverseProxy *ReverseProxy - Incr func() - Decr func() - CB CircuitBreaker - healthChecker *HealthChecker - healthCheckDur time.Duration - fastHealthCheckDur time.Duration -} - -// newUpstream returns a new upstream. -func newUpstream(uc *UpstreamConfig, d func(), i func()) (*upstream, error) { - host := strings.TrimSpace(uc.Host) - protoIdx := strings.Index(host, "://") - if protoIdx == -1 || len(host[:protoIdx]) == 0 { - return nil, fmt.Errorf("protocol is required for host") - } - - hostURL, err := url.Parse(host) - if err != nil { - return nil, err - } - - // parse healthcheck durations - hcd, err := time.ParseDuration(uc.HealthCheckDuration) - if err != nil { - hcd = defaultHealthCheckDur - } - - fhcd, err := time.ParseDuration(uc.FastHealthCheckDuration) - if err != nil { - fhcd = defaultFastHealthCheckDur - } - - u := upstream{ - healthCheckDur: hcd, - fastHealthCheckDur: fhcd, - Target: hostURL, - Decr: d, - Incr: i, - Healthy: int32(0), // assume is unhealthy on start - } - - u.ReverseProxy = newReverseProxy(hostURL, u.SetHealthiness) - return &u, nil -} - -// SetHealthiness sets whether an upstream is healthy or not. The health check worker is updated to -// perform checks faster if a node is unhealthy. -func (u *upstream) SetHealthiness(ok bool) { - h := atomic.LoadInt32(&u.Healthy) - var wasHealthy bool - if h == 1 { - wasHealthy = true - } else { - wasHealthy = false - } - - if ok { - u.healthChecker.Ticker = time.NewTicker(u.healthCheckDur) - - if !wasHealthy { - atomic.AddInt32(&u.Healthy, 1) - u.Decr() - } - } else { - u.healthChecker.Ticker = time.NewTicker(u.fastHealthCheckDur) - - if wasHealthy { - atomic.AddInt32(&u.Healthy, -1) - u.Incr() - } - } -} - -// IsHealthy returns whether an Upstream is healthy or not. -func (u *upstream) IsHealthy() bool { - i := atomic.LoadInt32(&u.Healthy) - if i == 1 { - return true - } - - return false -} - -// newReverseProxy returns a new reverse proxy handler. -func newReverseProxy(target *url.URL, setHealthiness func(bool)) *ReverseProxy { - errorHandler := func(w http.ResponseWriter, r *http.Request, err error) { - // we don't need to worry about cancelled contexts since this doesn't necessarilly mean that - // the upstream is unhealthy. - if err != context.Canceled { - setHealthiness(false) - } - } - - rp := NewSingleHostReverseProxy(target) - rp.ErrorHandler = errorHandler - rp.Transport = defaultTransport // use default transport that times out in 5 seconds - return rp -} - -// Interface guards -var ( - _ caddyhttp.MiddlewareHandler = (*LoadBalanced)(nil) - _ caddy.Provisioner = (*LoadBalanced)(nil) - _ caddy.CleanerUpper = (*LoadBalanced)(nil) -) |