summaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorMatthew Holt <mholt@users.noreply.github.com>2019-09-02 22:01:02 -0600
committerMatthew Holt <mholt@users.noreply.github.com>2019-09-02 22:01:02 -0600
commit026df7c5cb33331d223afc6a9599274e8c89dfd9 (patch)
treee7101ef1197b26b02472a567998fb83b57a45f08 /modules
parent2dc4fcc62bfab639b60758f26659c03e58af9960 (diff)
reverse_proxy: WIP refactor and support for FastCGI
Diffstat (limited to 'modules')
-rw-r--r--modules/caddyhttp/caddyhttp.go14
-rw-r--r--modules/caddyhttp/matchers.go5
-rw-r--r--modules/caddyhttp/reverseproxy/fastcgi/client.go578
-rw-r--r--modules/caddyhttp/reverseproxy/fastcgi/client_test.go301
-rw-r--r--modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go342
-rwxr-xr-xmodules/caddyhttp/reverseproxy/healthchecker.go86
-rw-r--r--modules/caddyhttp/reverseproxy/httptransport.go133
-rwxr-xr-xmodules/caddyhttp/reverseproxy/module.go53
-rw-r--r--[-rwxr-xr-x]modules/caddyhttp/reverseproxy/reverseproxy.go846
-rw-r--r--modules/caddyhttp/reverseproxy/selectionpolicies.go351
-rw-r--r--modules/caddyhttp/reverseproxy/selectionpolicies_test.go363
-rwxr-xr-xmodules/caddyhttp/reverseproxy/upstream.go450
12 files changed, 2649 insertions, 873 deletions
diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go
index b4b1ec6..300b5fd 100644
--- a/modules/caddyhttp/caddyhttp.go
+++ b/modules/caddyhttp/caddyhttp.go
@@ -518,6 +518,20 @@ func (ws WeakString) String() string {
return string(ws)
}
+// StatusCodeMatches returns true if a real HTTP status code matches
+// the configured status code, which may be either a real HTTP status
+// code or an integer representing a class of codes (e.g. 4 for all
+// 4xx statuses).
+func StatusCodeMatches(actual, configured int) bool {
+ if actual == configured {
+ return true
+ }
+ if configured < 100 && actual >= configured*100 && actual < (configured+1)*100 {
+ return true
+ }
+ return false
+}
+
const (
// DefaultHTTPPort is the default port for HTTP.
DefaultHTTPPort = 80
diff --git a/modules/caddyhttp/matchers.go b/modules/caddyhttp/matchers.go
index 0dac151..2ddefd0 100644
--- a/modules/caddyhttp/matchers.go
+++ b/modules/caddyhttp/matchers.go
@@ -616,10 +616,7 @@ func (rm ResponseMatcher) matchStatusCode(statusCode int) bool {
return true
}
for _, code := range rm.StatusCode {
- if statusCode == code {
- return true
- }
- if code < 100 && statusCode >= code*100 && statusCode < (code+1)*100 {
+ if StatusCodeMatches(statusCode, code) {
return true
}
}
diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client.go b/modules/caddyhttp/reverseproxy/fastcgi/client.go
new file mode 100644
index 0000000..ae0de00
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/fastcgi/client.go
@@ -0,0 +1,578 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Forked Jan. 2015 from http://bitbucket.org/PinIdea/fcgi_client
+// (which is forked from https://code.google.com/p/go-fastcgi-client/).
+// This fork contains several fixes and improvements by Matt Holt and
+// other contributors to the Caddy project.
+
+// Copyright 2012 Junqing Tan <ivan@mysqlab.net> and The Go Authors
+// Use of this source code is governed by a BSD-style
+// Part of source code is from Go fcgi package
+
+package fastcgi
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/binary"
+ "errors"
+ "io"
+ "io/ioutil"
+ "mime/multipart"
+ "net"
+ "net/http"
+ "net/http/httputil"
+ "net/textproto"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+// FCGIListenSockFileno describes listen socket file number.
+const FCGIListenSockFileno uint8 = 0
+
+// FCGIHeaderLen describes header length.
+const FCGIHeaderLen uint8 = 8
+
+// Version1 describes the version.
+const Version1 uint8 = 1
+
+// FCGINullRequestID describes the null request ID.
+const FCGINullRequestID uint8 = 0
+
+// FCGIKeepConn describes keep connection mode.
+const FCGIKeepConn uint8 = 1
+
+const (
+ // BeginRequest is the begin request flag.
+ BeginRequest uint8 = iota + 1
+ // AbortRequest is the abort request flag.
+ AbortRequest
+ // EndRequest is the end request flag.
+ EndRequest
+ // Params is the parameters flag.
+ Params
+ // Stdin is the standard input flag.
+ Stdin
+ // Stdout is the standard output flag.
+ Stdout
+ // Stderr is the standard error flag.
+ Stderr
+ // Data is the data flag.
+ Data
+ // GetValues is the get values flag.
+ GetValues
+ // GetValuesResult is the get values result flag.
+ GetValuesResult
+ // UnknownType is the unknown type flag.
+ UnknownType
+ // MaxType is the maximum type flag.
+ MaxType = UnknownType
+)
+
+const (
+ // Responder is the responder flag.
+ Responder uint8 = iota + 1
+ // Authorizer is the authorizer flag.
+ Authorizer
+ // Filter is the filter flag.
+ Filter
+)
+
+const (
+ // RequestComplete is the completed request flag.
+ RequestComplete uint8 = iota
+ // CantMultiplexConns is the multiplexed connections flag.
+ CantMultiplexConns
+ // Overloaded is the overloaded flag.
+ Overloaded
+ // UnknownRole is the unknown role flag.
+ UnknownRole
+)
+
+const (
+ // MaxConns is the maximum connections flag.
+ MaxConns string = "MAX_CONNS"
+ // MaxRequests is the maximum requests flag.
+ MaxRequests string = "MAX_REQS"
+ // MultiplexConns is the multiplex connections flag.
+ MultiplexConns string = "MPXS_CONNS"
+)
+
+const (
+ maxWrite = 65500 // 65530 may work, but for compatibility
+ maxPad = 255
+)
+
+type header struct {
+ Version uint8
+ Type uint8
+ ID uint16
+ ContentLength uint16
+ PaddingLength uint8
+ Reserved uint8
+}
+
+// for padding so we don't have to allocate all the time
+// not synchronized because we don't care what the contents are
+var pad [maxPad]byte
+
+func (h *header) init(recType uint8, reqID uint16, contentLength int) {
+ h.Version = 1
+ h.Type = recType
+ h.ID = reqID
+ h.ContentLength = uint16(contentLength)
+ h.PaddingLength = uint8(-contentLength & 7)
+}
+
+type record struct {
+ h header
+ rbuf []byte
+}
+
+func (rec *record) read(r io.Reader) (buf []byte, err error) {
+ if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil {
+ return
+ }
+ if rec.h.Version != 1 {
+ err = errors.New("fcgi: invalid header version")
+ return
+ }
+ if rec.h.Type == EndRequest {
+ err = io.EOF
+ return
+ }
+ n := int(rec.h.ContentLength) + int(rec.h.PaddingLength)
+ if len(rec.rbuf) < n {
+ rec.rbuf = make([]byte, n)
+ }
+ if _, err = io.ReadFull(r, rec.rbuf[:n]); err != nil {
+ return
+ }
+ buf = rec.rbuf[:int(rec.h.ContentLength)]
+
+ return
+}
+
+// FCGIClient implements a FastCGI client, which is a standard for
+// interfacing external applications with Web servers.
+type FCGIClient struct {
+ mutex sync.Mutex
+ rwc io.ReadWriteCloser
+ h header
+ buf bytes.Buffer
+ stderr bytes.Buffer
+ keepAlive bool
+ reqID uint16
+}
+
+// DialWithDialerContext connects to the fcgi responder at the specified network address, using custom net.Dialer
+// and a context.
+// See func net.Dial for a description of the network and address parameters.
+func DialWithDialerContext(ctx context.Context, network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) {
+ var conn net.Conn
+ conn, err = dialer.DialContext(ctx, network, address)
+ if err != nil {
+ return
+ }
+
+ fcgi = &FCGIClient{
+ rwc: conn,
+ keepAlive: false,
+ reqID: 1,
+ }
+
+ return
+}
+
+// DialContext is like Dial but passes ctx to dialer.Dial.
+func DialContext(ctx context.Context, network, address string) (fcgi *FCGIClient, err error) {
+ // TODO: why not set timeout here?
+ return DialWithDialerContext(ctx, network, address, net.Dialer{})
+}
+
+// Dial connects to the fcgi responder at the specified network address, using default net.Dialer.
+// See func net.Dial for a description of the network and address parameters.
+func Dial(network, address string) (fcgi *FCGIClient, err error) {
+ return DialContext(context.Background(), network, address)
+}
+
+// Close closes fcgi connection
+func (c *FCGIClient) Close() {
+ c.rwc.Close()
+}
+
+func (c *FCGIClient) writeRecord(recType uint8, content []byte) (err error) {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ c.buf.Reset()
+ c.h.init(recType, c.reqID, len(content))
+ if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(content); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil {
+ return err
+ }
+ _, err = c.rwc.Write(c.buf.Bytes())
+ return err
+}
+
+func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error {
+ b := [8]byte{byte(role >> 8), byte(role), flags}
+ return c.writeRecord(BeginRequest, b[:])
+}
+
+func (c *FCGIClient) writeEndRequest(appStatus int, protocolStatus uint8) error {
+ b := make([]byte, 8)
+ binary.BigEndian.PutUint32(b, uint32(appStatus))
+ b[4] = protocolStatus
+ return c.writeRecord(EndRequest, b)
+}
+
+func (c *FCGIClient) writePairs(recType uint8, pairs map[string]string) error {
+ w := newWriter(c, recType)
+ b := make([]byte, 8)
+ nn := 0
+ for k, v := range pairs {
+ m := 8 + len(k) + len(v)
+ if m > maxWrite {
+ // param data size exceed 65535 bytes"
+ vl := maxWrite - 8 - len(k)
+ v = v[:vl]
+ }
+ n := encodeSize(b, uint32(len(k)))
+ n += encodeSize(b[n:], uint32(len(v)))
+ m = n + len(k) + len(v)
+ if (nn + m) > maxWrite {
+ w.Flush()
+ nn = 0
+ }
+ nn += m
+ if _, err := w.Write(b[:n]); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(k); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(v); err != nil {
+ return err
+ }
+ }
+ w.Close()
+ return nil
+}
+
+func encodeSize(b []byte, size uint32) int {
+ if size > 127 {
+ size |= 1 << 31
+ binary.BigEndian.PutUint32(b, size)
+ return 4
+ }
+ b[0] = byte(size)
+ return 1
+}
+
+// bufWriter encapsulates bufio.Writer but also closes the underlying stream when
+// Closed.
+type bufWriter struct {
+ closer io.Closer
+ *bufio.Writer
+}
+
+func (w *bufWriter) Close() error {
+ if err := w.Writer.Flush(); err != nil {
+ w.closer.Close()
+ return err
+ }
+ return w.closer.Close()
+}
+
+func newWriter(c *FCGIClient, recType uint8) *bufWriter {
+ s := &streamWriter{c: c, recType: recType}
+ w := bufio.NewWriterSize(s, maxWrite)
+ return &bufWriter{s, w}
+}
+
+// streamWriter abstracts out the separation of a stream into discrete records.
+// It only writes maxWrite bytes at a time.
+type streamWriter struct {
+ c *FCGIClient
+ recType uint8
+}
+
+func (w *streamWriter) Write(p []byte) (int, error) {
+ nn := 0
+ for len(p) > 0 {
+ n := len(p)
+ if n > maxWrite {
+ n = maxWrite
+ }
+ if err := w.c.writeRecord(w.recType, p[:n]); err != nil {
+ return nn, err
+ }
+ nn += n
+ p = p[n:]
+ }
+ return nn, nil
+}
+
+func (w *streamWriter) Close() error {
+ // send empty record to close the stream
+ return w.c.writeRecord(w.recType, nil)
+}
+
+type streamReader struct {
+ c *FCGIClient
+ buf []byte
+}
+
+func (w *streamReader) Read(p []byte) (n int, err error) {
+
+ if len(p) > 0 {
+ if len(w.buf) == 0 {
+
+ // filter outputs for error log
+ for {
+ rec := &record{}
+ var buf []byte
+ buf, err = rec.read(w.c.rwc)
+ if err != nil {
+ return
+ }
+ // standard error output
+ if rec.h.Type == Stderr {
+ w.c.stderr.Write(buf)
+ continue
+ }
+ w.buf = buf
+ break
+ }
+ }
+
+ n = len(p)
+ if n > len(w.buf) {
+ n = len(w.buf)
+ }
+ copy(p, w.buf[:n])
+ w.buf = w.buf[n:]
+ }
+
+ return
+}
+
+// Do made the request and returns a io.Reader that translates the data read
+// from fcgi responder out of fcgi packet before returning it.
+func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) {
+ err = c.writeBeginRequest(uint16(Responder), 0)
+ if err != nil {
+ return
+ }
+
+ err = c.writePairs(Params, p)
+ if err != nil {
+ return
+ }
+
+ body := newWriter(c, Stdin)
+ if req != nil {
+ _, _ = io.Copy(body, req)
+ }
+ body.Close()
+
+ r = &streamReader{c: c}
+ return
+}
+
+// clientCloser is a io.ReadCloser. It wraps a io.Reader with a Closer
+// that closes FCGIClient connection.
+type clientCloser struct {
+ *FCGIClient
+ io.Reader
+}
+
+func (f clientCloser) Close() error { return f.rwc.Close() }
+
+// Request returns a HTTP Response with Header and Body
+// from fcgi responder
+func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Response, err error) {
+ r, err := c.Do(p, req)
+ if err != nil {
+ return
+ }
+
+ rb := bufio.NewReader(r)
+ tp := textproto.NewReader(rb)
+ resp = new(http.Response)
+
+ // Parse the response headers.
+ mimeHeader, err := tp.ReadMIMEHeader()
+ if err != nil && err != io.EOF {
+ return
+ }
+ resp.Header = http.Header(mimeHeader)
+
+ if resp.Header.Get("Status") != "" {
+ statusParts := strings.SplitN(resp.Header.Get("Status"), " ", 2)
+ resp.StatusCode, err = strconv.Atoi(statusParts[0])
+ if err != nil {
+ return
+ }
+ if len(statusParts) > 1 {
+ resp.Status = statusParts[1]
+ }
+
+ } else {
+ resp.StatusCode = http.StatusOK
+ }
+
+ // TODO: fixTransferEncoding ?
+ resp.TransferEncoding = resp.Header["Transfer-Encoding"]
+ resp.ContentLength, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
+
+ if chunked(resp.TransferEncoding) {
+ resp.Body = clientCloser{c, httputil.NewChunkedReader(rb)}
+ } else {
+ resp.Body = clientCloser{c, ioutil.NopCloser(rb)}
+ }
+ return
+}
+
+// Get issues a GET request to the fcgi responder.
+func (c *FCGIClient) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) {
+
+ p["REQUEST_METHOD"] = "GET"
+ p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10)
+
+ return c.Request(p, body)
+}
+
+// Head issues a HEAD request to the fcgi responder.
+func (c *FCGIClient) Head(p map[string]string) (resp *http.Response, err error) {
+
+ p["REQUEST_METHOD"] = "HEAD"
+ p["CONTENT_LENGTH"] = "0"
+
+ return c.Request(p, nil)
+}
+
+// Options issues an OPTIONS request to the fcgi responder.
+func (c *FCGIClient) Options(p map[string]string) (resp *http.Response, err error) {
+
+ p["REQUEST_METHOD"] = "OPTIONS"
+ p["CONTENT_LENGTH"] = "0"
+
+ return c.Request(p, nil)
+}
+
+// Post issues a POST request to the fcgi responder. with request body
+// in the format that bodyType specified
+func (c *FCGIClient) Post(p map[string]string, method string, bodyType string, body io.Reader, l int64) (resp *http.Response, err error) {
+ if p == nil {
+ p = make(map[string]string)
+ }
+
+ p["REQUEST_METHOD"] = strings.ToUpper(method)
+
+ if len(p["REQUEST_METHOD"]) == 0 || p["REQUEST_METHOD"] == "GET" {
+ p["REQUEST_METHOD"] = "POST"
+ }
+
+ p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10)
+ if len(bodyType) > 0 {
+ p["CONTENT_TYPE"] = bodyType
+ } else {
+ p["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
+ }
+
+ return c.Request(p, body)
+}
+
+// PostForm issues a POST to the fcgi responder, with form
+// as a string key to a list values (url.Values)
+func (c *FCGIClient) PostForm(p map[string]string, data url.Values) (resp *http.Response, err error) {
+ body := bytes.NewReader([]byte(data.Encode()))
+ return c.Post(p, "POST", "application/x-www-form-urlencoded", body, int64(body.Len()))
+}
+
+// PostFile issues a POST to the fcgi responder in multipart(RFC 2046) standard,
+// with form as a string key to a list values (url.Values),
+// and/or with file as a string key to a list file path.
+func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[string]string) (resp *http.Response, err error) {
+ buf := &bytes.Buffer{}
+ writer := multipart.NewWriter(buf)
+ bodyType := writer.FormDataContentType()
+
+ for key, val := range data {
+ for _, v0 := range val {
+ err = writer.WriteField(key, v0)
+ if err != nil {
+ return
+ }
+ }
+ }
+
+ for key, val := range file {
+ fd, e := os.Open(val)
+ if e != nil {
+ return nil, e
+ }
+ defer fd.Close()
+
+ part, e := writer.CreateFormFile(key, filepath.Base(val))
+ if e != nil {
+ return nil, e
+ }
+ _, err = io.Copy(part, fd)
+ if err != nil {
+ return
+ }
+ }
+
+ err = writer.Close()
+ if err != nil {
+ return
+ }
+
+ return c.Post(p, "POST", bodyType, buf, int64(buf.Len()))
+}
+
+// SetReadTimeout sets the read timeout for future calls that read from the
+// fcgi responder. A zero value for t means no timeout will be set.
+func (c *FCGIClient) SetReadTimeout(t time.Duration) error {
+ if conn, ok := c.rwc.(net.Conn); ok && t != 0 {
+ return conn.SetReadDeadline(time.Now().Add(t))
+ }
+ return nil
+}
+
+// SetWriteTimeout sets the write timeout for future calls that send data to
+// the fcgi responder. A zero value for t means no timeout will be set.
+func (c *FCGIClient) SetWriteTimeout(t time.Duration) error {
+ if conn, ok := c.rwc.(net.Conn); ok && t != 0 {
+ return conn.SetWriteDeadline(time.Now().Add(t))
+ }
+ return nil
+}
+
+// Checks whether chunked is part of the encodings stack
+func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client_test.go b/modules/caddyhttp/reverseproxy/fastcgi/client_test.go
new file mode 100644
index 0000000..c090f3c
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/fastcgi/client_test.go
@@ -0,0 +1,301 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE: These tests were adapted from the original
+// repository from which this package was forked.
+// The tests are slow (~10s) and in dire need of rewriting.
+// As such, the tests have been disabled to speed up
+// automated builds until they can be properly written.
+
+package fastcgi
+
+import (
+ "bytes"
+ "crypto/md5"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "math/rand"
+ "net"
+ "net/http"
+ "net/http/fcgi"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+// test fcgi protocol includes:
+// Get, Post, Post in multipart/form-data, and Post with files
+// each key should be the md5 of the value or the file uploaded
+// specify remote fcgi responder ip:port to test with php
+// test failed if the remote fcgi(script) failed md5 verification
+// and output "FAILED" in response
+const (
+ scriptFile = "/tank/www/fcgic_test.php"
+ //ipPort = "remote-php-serv:59000"
+ ipPort = "127.0.0.1:59000"
+)
+
+var globalt *testing.T
+
+type FastCGIServer struct{}
+
+func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
+
+ if err := req.ParseMultipartForm(100000000); err != nil {
+ log.Printf("[ERROR] failed to parse: %v", err)
+ }
+
+ stat := "PASSED"
+ fmt.Fprintln(resp, "-")
+ fileNum := 0
+ {
+ length := 0
+ for k0, v0 := range req.Form {
+ h := md5.New()
+ _, _ = io.WriteString(h, v0[0])
+ _md5 := fmt.Sprintf("%x", h.Sum(nil))
+
+ length += len(k0)
+ length += len(v0[0])
+
+ // echo error when key != _md5(val)
+ if _md5 != k0 {
+ fmt.Fprintln(resp, "server:err ", _md5, k0)
+ stat = "FAILED"
+ }
+ }
+ if req.MultipartForm != nil {
+ fileNum = len(req.MultipartForm.File)
+ for kn, fns := range req.MultipartForm.File {
+ //fmt.Fprintln(resp, "server:filekey ", kn )
+ length += len(kn)
+ for _, f := range fns {
+ fd, err := f.Open()
+ if err != nil {
+ log.Println("server:", err)
+ return
+ }
+ h := md5.New()
+ l0, err := io.Copy(h, fd)
+ if err != nil {
+ log.Println(err)
+ return
+ }
+ length += int(l0)
+ defer fd.Close()
+ md5 := fmt.Sprintf("%x", h.Sum(nil))
+ //fmt.Fprintln(resp, "server:filemd5 ", md5 )
+
+ if kn != md5 {
+ fmt.Fprintln(resp, "server:err ", md5, kn)
+ stat = "FAILED"
+ }
+ //fmt.Fprintln(resp, "server:filename ", f.Filename )
+ }
+ }
+ }
+
+ fmt.Fprintln(resp, "server:got data length", length)
+ }
+ fmt.Fprintln(resp, "-"+stat+"-POST(", len(req.Form), ")-FILE(", fileNum, ")--")
+}
+
+func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) {
+ fcgi, err := Dial("tcp", ipPort)
+ if err != nil {
+ log.Println("err:", err)
+ return
+ }
+
+ length := 0
+
+ var resp *http.Response
+ switch reqType {
+ case 0:
+ if len(data) > 0 {
+ length = len(data)
+ rd := bytes.NewReader(data)
+ resp, err = fcgi.Post(fcgiParams, "", "", rd, int64(rd.Len()))
+ } else if len(posts) > 0 {
+ values := url.Values{}
+ for k, v := range posts {
+ values.Set(k, v)
+ length += len(k) + 2 + len(v)
+ }
+ resp, err = fcgi.PostForm(fcgiParams, values)
+ } else {
+ rd := bytes.NewReader(data)
+ resp, err = fcgi.Get(fcgiParams, rd, int64(rd.Len()))
+ }
+
+ default:
+ values := url.Values{}
+ for k, v := range posts {
+ values.Set(k, v)
+ length += len(k) + 2 + len(v)
+ }
+
+ for k, v := range files {
+ fi, _ := os.Lstat(v)
+ length += len(k) + int(fi.Size())
+ }
+ resp, err = fcgi.PostFile(fcgiParams, values, files)
+ }
+
+ if err != nil {
+ log.Println("err:", err)
+ return
+ }
+
+ defer resp.Body.Close()
+ content, _ = ioutil.ReadAll(resp.Body)
+
+ log.Println("c: send data length ≈", length, string(content))
+ fcgi.Close()
+ time.Sleep(1 * time.Second)
+
+ if bytes.Contains(content, []byte("FAILED")) {
+ globalt.Error("Server return failed message")
+ }
+
+ return
+}
+
+func generateRandFile(size int) (p string, m string) {
+
+ p = filepath.Join(os.TempDir(), "fcgict"+strconv.Itoa(rand.Int()))
+
+ // open output file
+ fo, err := os.Create(p)
+ if err != nil {
+ panic(err)
+ }
+ // close fo on exit and check for its returned error
+ defer func() {
+ if err := fo.Close(); err != nil {
+ panic(err)
+ }
+ }()
+
+ h := md5.New()
+ for i := 0; i < size/16; i++ {
+ buf := make([]byte, 16)
+ binary.PutVarint(buf, rand.Int63())
+ if _, err := fo.Write(buf); err != nil {
+ log.Printf("[ERROR] failed to write buffer: %v\n", err)
+ }
+ if _, err := h.Write(buf); err != nil {
+ log.Printf("[ERROR] failed to write buffer: %v\n", err)
+ }
+ }
+ m = fmt.Sprintf("%x", h.Sum(nil))
+ return
+}
+
+func DisabledTest(t *testing.T) {
+ // TODO: test chunked reader
+ globalt = t
+
+ rand.Seed(time.Now().UTC().UnixNano())
+
+ // server
+ go func() {
+ listener, err := net.Listen("tcp", ipPort)
+ if err != nil {
+ log.Println("listener creation failed: ", err)
+ }
+
+ srv := new(FastCGIServer)
+ if err := fcgi.Serve(listener, srv); err != nil {
+ log.Print("[ERROR] failed to start server: ", err)
+ }
+ }()
+
+ time.Sleep(1 * time.Second)
+
+ // init
+ fcgiParams := make(map[string]string)
+ fcgiParams["REQUEST_METHOD"] = "GET"
+ fcgiParams["SERVER_PROTOCOL"] = "HTTP/1.1"
+ //fcgi_params["GATEWAY_INTERFACE"] = "CGI/1.1"
+ fcgiParams["SCRIPT_FILENAME"] = scriptFile
+
+ // simple GET
+ log.Println("test:", "get")
+ sendFcgi(0, fcgiParams, nil, nil, nil)
+
+ // simple post data
+ log.Println("test:", "post")
+ sendFcgi(0, fcgiParams, []byte("c4ca4238a0b923820dcc509a6f75849b=1&7b8b965ad4bca0e41ab51de7b31363a1=n"), nil, nil)
+
+ log.Println("test:", "post data (more than 60KB)")
+ data := ""
+ for i := 0x00; i < 0xff; i++ {
+ v0 := strings.Repeat(string(i), 256)
+ h := md5.New()
+ _, _ = io.WriteString(h, v0)
+ k0 := fmt.Sprintf("%x", h.Sum(nil))
+ data += k0 + "=" + url.QueryEscape(v0) + "&"
+ }
+ sendFcgi(0, fcgiParams, []byte(data), nil, nil)
+
+ log.Println("test:", "post form (use url.Values)")
+ p0 := make(map[string]string, 1)
+ p0["c4ca4238a0b923820dcc509a6f75849b"] = "1"
+ p0["7b8b965ad4bca0e41ab51de7b31363a1"] = "n"
+ sendFcgi(1, fcgiParams, nil, p0, nil)
+
+ log.Println("test:", "post forms (256 keys, more than 1MB)")
+ p1 := make(map[string]string, 1)
+ for i := 0x00; i < 0xff; i++ {
+ v0 := strings.Repeat(string(i), 4096)
+ h := md5.New()
+ _, _ = io.WriteString(h, v0)
+ k0 := fmt.Sprintf("%x", h.Sum(nil))
+ p1[k0] = v0
+ }
+ sendFcgi(1, fcgiParams, nil, p1, nil)
+
+ log.Println("test:", "post file (1 file, 500KB)) ")
+ f0 := make(map[string]string, 1)
+ path0, m0 := generateRandFile(500000)
+ f0[m0] = path0
+ sendFcgi(1, fcgiParams, nil, p1, f0)
+
+ log.Println("test:", "post multiple files (2 files, 5M each) and forms (256 keys, more than 1MB data")
+ path1, m1 := generateRandFile(5000000)
+ f0[m1] = path1
+ sendFcgi(1, fcgiParams, nil, p1, f0)
+
+ log.Println("test:", "post only files (2 files, 5M each)")
+ sendFcgi(1, fcgiParams, nil, nil, f0)
+
+ log.Println("test:", "post only 1 file")
+ delete(f0, "m0")
+ sendFcgi(1, fcgiParams, nil, nil, f0)
+
+ if err := os.Remove(path0); err != nil {
+ log.Println("[ERROR] failed to remove path: ", err)
+ }
+ if err := os.Remove(path1); err != nil {
+ log.Println("[ERROR] failed to remove path: ", err)
+ }
+}
diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go
new file mode 100644
index 0000000..32f094b
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go
@@ -0,0 +1,342 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fastcgi
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "net/http"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/caddyserver/caddy/v2/modules/caddytls"
+
+ "github.com/caddyserver/caddy/v2"
+)
+
+func init() {
+ caddy.RegisterModule(Transport{})
+}
+
+type Transport struct {
+ //////////////////////////////
+ // TODO: taken from v1 Handler type
+
+ SoftwareName string
+ SoftwareVersion string
+ ServerName string
+ ServerPort string
+
+ //////////////////////////
+ // TODO: taken from v1 Rule type
+
+ // The base path to match. Required.
+ // Path string
+
+ // upstream load balancer
+ // balancer
+
+ // Always process files with this extension with fastcgi.
+ // Ext string
+
+ // Use this directory as the fastcgi root directory. Defaults to the root
+ // directory of the parent virtual host.
+ Root string
+
+ // The path in the URL will be split into two, with the first piece ending
+ // with the value of SplitPath. The first piece will be assumed as the
+ // actual resource (CGI script) name, and the second piece will be set to
+ // PATH_INFO for the CGI script to use.
+ SplitPath string
+
+ // If the URL ends with '/' (which indicates a directory), these index
+ // files will be tried instead.
+ IndexFiles []string
+
+ // Environment Variables
+ EnvVars [][2]string
+
+ // Ignored paths
+ IgnoredSubPaths []string
+
+ // The duration used to set a deadline when connecting to an upstream.
+ DialTimeout time.Duration
+
+ // The duration used to set a deadline when reading from the FastCGI server.
+ ReadTimeout time.Duration
+
+ // The duration used to set a deadline when sending to the FastCGI server.
+ WriteTimeout time.Duration
+}
+
+// CaddyModule returns the Caddy module information.
+func (Transport) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.transport.fastcgi",
+ New: func() caddy.Module { return new(Transport) },
+ }
+}
+
+func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
+ // Create environment for CGI script
+ env, err := t.buildEnv(r)
+ if err != nil {
+ return nil, fmt.Errorf("building environment: %v", err)
+ }
+
+ // TODO:
+ // Connect to FastCGI gateway
+ // address, err := f.Address()
+ // if err != nil {
+ // return http.StatusBadGateway, err
+ // }
+ // network, address := parseAddress(address)
+ network, address := "tcp", r.URL.Host // TODO:
+
+ ctx := context.Background()
+ if t.DialTimeout > 0 {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithTimeout(ctx, t.DialTimeout)
+ defer cancel()
+ }
+
+ fcgiBackend, err := DialContext(ctx, network, address)
+ if err != nil {
+ return nil, fmt.Errorf("dialing backend: %v", err)
+ }
+ // fcgiBackend is closed when response body is closed (see clientCloser)
+
+ // read/write timeouts
+ if err := fcgiBackend.SetReadTimeout(t.ReadTimeout); err != nil {
+ return nil, fmt.Errorf("setting read timeout: %v", err)
+ }
+ if err := fcgiBackend.SetWriteTimeout(t.WriteTimeout); err != nil {
+ return nil, fmt.Errorf("setting write timeout: %v", err)
+ }
+
+ var resp *http.Response
+
+ var contentLength int64
+ // if ContentLength is already set
+ if r.ContentLength > 0 {
+ contentLength = r.ContentLength
+ } else {
+ contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
+ }
+ switch r.Method {
+ case "HEAD":
+ resp, err = fcgiBackend.Head(env)
+ case "GET":
+ resp, err = fcgiBackend.Get(env, r.Body, contentLength)
+ case "OPTIONS":
+ resp, err = fcgiBackend.Options(env)
+ default:
+ resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
+ }
+
+ // TODO:
+ return resp, err
+
+ // Stuff brought over from v1 that might not be necessary here:
+
+ // if resp != nil && resp.Body != nil {
+ // defer resp.Body.Close()
+ // }
+
+ // if err != nil {
+ // if err, ok := err.(net.Error); ok && err.Timeout() {
+ // return http.StatusGatewayTimeout, err
+ // } else if err != io.EOF {
+ // return http.StatusBadGateway, err
+ // }
+ // }
+
+ // // Write response header
+ // writeHeader(w, resp)
+
+ // // Write the response body
+ // _, err = io.Copy(w, resp.Body)
+ // if err != nil {
+ // return http.StatusBadGateway, err
+ // }
+
+ // // Log any stderr output from upstream
+ // if fcgiBackend.stderr.Len() != 0 {
+ // // Remove trailing newline, error logger already does this.
+ // err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
+ // }
+
+ // // Normally we would return the status code if it is an error status (>= 400),
+ // // however, upstream FastCGI apps don't know about our contract and have
+ // // probably already written an error page. So we just return 0, indicating
+ // // that the response body is already written. However, we do return any
+ // // error value so it can be logged.
+ // // Note that the proxy middleware works the same way, returning status=0.
+ // return 0, err
+}
+
+// buildEnv returns a set of CGI environment variables for the request.
+func (t Transport) buildEnv(r *http.Request) (map[string]string, error) {
+ var env map[string]string
+
+ // Separate remote IP and port; more lenient than net.SplitHostPort
+ var ip, port string
+ if idx := strings.LastIndex(r.RemoteAddr, ":"); idx > -1 {
+ ip = r.RemoteAddr[:idx]
+ port = r.RemoteAddr[idx+1:]
+ } else {
+ ip = r.RemoteAddr
+ }
+
+ // Remove [] from IPv6 addresses
+ ip = strings.Replace(ip, "[", "", 1)
+ ip = strings.Replace(ip, "]", "", 1)
+
+ // TODO: respect index files? or leave that to matcher/rewrite (I prefer that)?
+ fpath := r.URL.Path
+
+ // Split path in preparation for env variables.
+ // Previous canSplit checks ensure this can never be -1.
+ // TODO: I haven't brought over canSplit; make sure this doesn't break
+ splitPos := t.splitPos(fpath)
+
+ // Request has the extension; path was split successfully
+ docURI := fpath[:splitPos+len(t.SplitPath)]
+ pathInfo := fpath[splitPos+len(t.SplitPath):]
+ scriptName := fpath
+
+ // Strip PATH_INFO from SCRIPT_NAME
+ scriptName = strings.TrimSuffix(scriptName, pathInfo)
+
+ // SCRIPT_FILENAME is the absolute path of SCRIPT_NAME
+ scriptFilename := filepath.Join(t.Root, scriptName)
+
+ // Add vhost path prefix to scriptName. Otherwise, some PHP software will
+ // have difficulty discovering its URL.
+ pathPrefix, _ := r.Context().Value(caddy.CtxKey("path_prefix")).(string)
+ scriptName = path.Join(pathPrefix, scriptName)
+
+ // TODO: Disabled for now
+ // // Get the request URI from context. The context stores the original URI in case
+ // // it was changed by a middleware such as rewrite. By default, we pass the
+ // // original URI in as the value of REQUEST_URI (the user can overwrite this
+ // // if desired). Most PHP apps seem to want the original URI. Besides, this is
+ // // how nginx defaults: http://stackoverflow.com/a/12485156/1048862
+ // reqURL, _ := r.Context().Value(httpserver.OriginalURLCtxKey).(url.URL)
+
+ // // Retrieve name of remote user that was set by some downstream middleware such as basicauth.
+ // remoteUser, _ := r.Context().Value(httpserver.RemoteUserCtxKey).(string)
+
+ requestScheme := "http"
+ if r.TLS != nil {
+ requestScheme = "https"
+ }
+
+ // Some variables are unused but cleared explicitly to prevent
+ // the parent environment from interfering.
+ env = map[string]string{
+ // Variables defined in CGI 1.1 spec
+ "AUTH_TYPE": "", // Not used
+ "CONTENT_LENGTH": r.Header.Get("Content-Length"),
+ "CONTENT_TYPE": r.Header.Get("Content-Type"),
+ "GATEWAY_INTERFACE": "CGI/1.1",
+ "PATH_INFO": pathInfo,
+ "QUERY_STRING": r.URL.RawQuery,
+ "REMOTE_ADDR": ip,
+ "REMOTE_HOST": ip, // For speed, remote host lookups disabled
+ "REMOTE_PORT": port,
+ "REMOTE_IDENT": "", // Not used
+ // "REMOTE_USER": remoteUser, // TODO:
+ "REQUEST_METHOD": r.Method,
+ "REQUEST_SCHEME": requestScheme,
+ "SERVER_NAME": t.ServerName,
+ "SERVER_PORT": t.ServerPort,
+ "SERVER_PROTOCOL": r.Proto,
+ "SERVER_SOFTWARE": t.SoftwareName + "/" + t.SoftwareVersion,
+
+ // Other variables
+ // "DOCUMENT_ROOT": rule.Root,
+ "DOCUMENT_URI": docURI,
+ "HTTP_HOST": r.Host, // added here, since not always part of headers
+ // "REQUEST_URI": reqURL.RequestURI(), // TODO:
+ "SCRIPT_FILENAME": scriptFilename,
+ "SCRIPT_NAME": scriptName,
+ }
+
+ // compliance with the CGI specification requires that
+ // PATH_TRANSLATED should only exist if PATH_INFO is defined.
+ // Info: https://www.ietf.org/rfc/rfc3875 Page 14
+ if env["PATH_INFO"] != "" {
+ env["PATH_TRANSLATED"] = filepath.Join(t.Root, pathInfo) // Info: http://www.oreilly.com/openbook/cgi/ch02_04.html
+ }
+
+ // Some web apps rely on knowing HTTPS or not
+ if r.TLS != nil {
+ env["HTTPS"] = "on"
+ // and pass the protocol details in a manner compatible with apache's mod_ssl
+ // (which is why these have a SSL_ prefix and not TLS_).
+ v, ok := tlsProtocolStrings[r.TLS.Version]
+ if ok {
+ env["SSL_PROTOCOL"] = v
+ }
+ // and pass the cipher suite in a manner compatible with apache's mod_ssl
+ for k, v := range caddytls.SupportedCipherSuites {
+ if v == r.TLS.CipherSuite {
+ env["SSL_CIPHER"] = k
+ break
+ }
+ }
+ }
+
+ // Add env variables from config (with support for placeholders in values)
+ repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer)
+ for _, envVar := range t.EnvVars {
+ env[envVar[0]] = repl.ReplaceAll(envVar[1], "")
+ }
+
+ // Add all HTTP headers to env variables
+ for field, val := range r.Header {
+ header := strings.ToUpper(field)
+ header = headerNameReplacer.Replace(header)
+ env["HTTP_"+header] = strings.Join(val, ", ")
+ }
+ return env, nil
+}
+
+// splitPos returns the index where path should
+// be split based on t.SplitPath.
+func (t Transport) splitPos(path string) int {
+ // TODO:
+ // if httpserver.CaseSensitivePath {
+ // return strings.Index(path, r.SplitPath)
+ // }
+ return strings.Index(strings.ToLower(path), strings.ToLower(t.SplitPath))
+}
+
+// TODO:
+// Map of supported protocols to Apache ssl_mod format
+// Note that these are slightly different from SupportedProtocols in caddytls/config.go
+var tlsProtocolStrings = map[uint16]string{
+ tls.VersionTLS10: "TLSv1",
+ tls.VersionTLS11: "TLSv1.1",
+ tls.VersionTLS12: "TLSv1.2",
+ tls.VersionTLS13: "TLSv1.3",
+}
+
+var headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_")
diff --git a/modules/caddyhttp/reverseproxy/healthchecker.go b/modules/caddyhttp/reverseproxy/healthchecker.go
deleted file mode 100755
index c557d3f..0000000
--- a/modules/caddyhttp/reverseproxy/healthchecker.go
+++ /dev/null
@@ -1,86 +0,0 @@
-// Copyright 2015 Matthew Holt and The Caddy Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package reverseproxy
-
-import (
- "net/http"
- "time"
-)
-
-// Upstream represents the interface that must be satisfied to use the healthchecker.
-type Upstream interface {
- SetHealthiness(bool)
-}
-
-// HealthChecker represents a worker that periodically evaluates if proxy upstream host is healthy.
-type HealthChecker struct {
- upstream Upstream
- Ticker *time.Ticker
- HTTPClient *http.Client
- StopChan chan bool
-}
-
-// ScheduleChecks periodically runs health checks against an upstream host.
-func (h *HealthChecker) ScheduleChecks(url string) {
- // check if a host is healthy on start vs waiting for timer
- h.upstream.SetHealthiness(h.IsHealthy(url))
- stop := make(chan bool)
- h.StopChan = stop
-
- go func() {
- for {
- select {
- case <-h.Ticker.C:
- h.upstream.SetHealthiness(h.IsHealthy(url))
- case <-stop:
- return
- }
- }
- }()
-}
-
-// Stop stops the healthchecker from makeing further requests.
-func (h *HealthChecker) Stop() {
- h.Ticker.Stop()
- close(h.StopChan)
-}
-
-// IsHealthy attempts to check if a upstream host is healthy.
-func (h *HealthChecker) IsHealthy(url string) bool {
- req, err := http.NewRequest(http.MethodGet, url, nil)
- if err != nil {
- return false
- }
-
- resp, err := h.HTTPClient.Do(req)
- if err != nil {
- return false
- }
-
- if resp.StatusCode < 200 || resp.StatusCode >= 400 {
- return false
- }
-
- return true
-}
-
-// NewHealthCheckWorker returns a new instance of a HealthChecker.
-func NewHealthCheckWorker(u Upstream, interval time.Duration, client *http.Client) *HealthChecker {
- return &HealthChecker{
- upstream: u,
- Ticker: time.NewTicker(interval),
- HTTPClient: client,
- }
-}
diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go
new file mode 100644
index 0000000..36dd776
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/httptransport.go
@@ -0,0 +1,133 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reverseproxy
+
+import (
+ "net"
+ "net/http"
+ "time"
+
+ "github.com/caddyserver/caddy/v2"
+)
+
+func init() {
+ caddy.RegisterModule(HTTPTransport{})
+}
+
+// TODO: This is the default transport, basically just http.Transport, but we define JSON struct tags...
+type HTTPTransport struct {
+ // TODO: Actually this is where the TLS config should go, technically...
+ // as well as keepalives and dial timeouts...
+ // TODO: It's possible that other transports (like fastcgi) might be
+ // able to borrow/use at least some of these config fields; if so,
+ // move them into a type called CommonTransport and embed it
+
+ TLS *TLSConfig `json:"tls,omitempty"`
+ KeepAlive *KeepAlive `json:"keep_alive,omitempty"`
+ Compression *bool `json:"compression,omitempty"`
+ MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // TODO: NOTE: we use our health check stuff to enforce max REQUESTS per host, but this is connections
+ DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
+ FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
+ ResponseHeaderTimeout caddy.Duration `json:"response_header_timeout,omitempty"`
+ ExpectContinueTimeout caddy.Duration `json:"expect_continue_timeout,omitempty"`
+ MaxResponseHeaderSize int64 `json:"max_response_header_size,omitempty"`
+ WriteBufferSize int `json:"write_buffer_size,omitempty"`
+ ReadBufferSize int `json:"read_buffer_size,omitempty"`
+ // TODO: ProxyConnectHeader?
+
+ RoundTripper http.RoundTripper `json:"-"`
+}
+
+// CaddyModule returns the Caddy module information.
+func (HTTPTransport) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.transport.http",
+ New: func() caddy.Module { return new(HTTPTransport) },
+ }
+}
+
+func (h *HTTPTransport) Provision(ctx caddy.Context) error {
+ dialer := &net.Dialer{
+ Timeout: time.Duration(h.DialTimeout),
+ FallbackDelay: time.Duration(h.FallbackDelay),
+ // TODO: Resolver
+ }
+ rt := &http.Transport{
+ DialContext: dialer.DialContext,
+ MaxConnsPerHost: h.MaxConnsPerHost,
+ ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
+ ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout),
+ MaxResponseHeaderBytes: h.MaxResponseHeaderSize,
+ WriteBufferSize: h.WriteBufferSize,
+ ReadBufferSize: h.ReadBufferSize,
+ }
+
+ if h.TLS != nil {
+ rt.TLSHandshakeTimeout = time.Duration(h.TLS.HandshakeTimeout)
+ // TODO: rest of TLS config
+ }
+
+ if h.KeepAlive != nil {
+ dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval)
+
+ if enabled := h.KeepAlive.Enabled; enabled != nil {
+ rt.DisableKeepAlives = !*enabled
+ }
+ rt.MaxIdleConns = h.KeepAlive.MaxIdleConns
+ rt.MaxIdleConnsPerHost = h.KeepAlive.MaxIdleConnsPerHost
+ rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout)
+ }
+
+ if h.Compression != nil {
+ rt.DisableCompression = !*h.Compression
+ }
+
+ h.RoundTripper = rt
+
+ return nil
+}
+
+func (h HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ return h.RoundTripper.RoundTrip(req)
+}
+
+type TLSConfig struct {
+ CAPool []string `json:"ca_pool,omitempty"`
+ ClientCertificate string `json:"client_certificate,omitempty"`
+ InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"`
+ HandshakeTimeout caddy.Duration `json:"handshake_timeout,omitempty"`
+}
+
+type KeepAlive struct {
+ Enabled *bool `json:"enabled,omitempty"`
+ ProbeInterval caddy.Duration `json:"probe_interval,omitempty"`
+ MaxIdleConns int `json:"max_idle_conns,omitempty"`
+ MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"`
+ IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle
+}
+
+var (
+ defaultDialer = net.Dialer{
+ Timeout: 10 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }
+
+ // TODO: does this need to be configured to enable HTTP/2?
+ defaultTransport = &http.Transport{
+ DialContext: defaultDialer.DialContext,
+ TLSHandshakeTimeout: 5 * time.Second,
+ IdleConnTimeout: 2 * time.Minute,
+ }
+)
diff --git a/modules/caddyhttp/reverseproxy/module.go b/modules/caddyhttp/reverseproxy/module.go
deleted file mode 100755
index 21aca1d..0000000
--- a/modules/caddyhttp/reverseproxy/module.go
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright 2015 Matthew Holt and The Caddy Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package reverseproxy
-
-import (
- "github.com/caddyserver/caddy/v2"
- "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
- "github.com/caddyserver/caddy/v2/modules/caddyhttp"
-)
-
-func init() {
- caddy.RegisterModule(new(LoadBalanced))
- httpcaddyfile.RegisterHandlerDirective("reverse_proxy", parseCaddyfile) // TODO: "proxy"?
-}
-
-// CaddyModule returns the Caddy module information.
-func (*LoadBalanced) CaddyModule() caddy.ModuleInfo {
- return caddy.ModuleInfo{
- Name: "http.handlers.reverse_proxy",
- New: func() caddy.Module { return new(LoadBalanced) },
- }
-}
-
-// parseCaddyfile sets up the handler from Caddyfile tokens. Syntax:
-//
-// proxy [<matcher>] <to>
-//
-// TODO: This needs to be finished. It definitely needs to be able to open a block...
-func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
- lb := new(LoadBalanced)
- for h.Next() {
- allTo := h.RemainingArgs()
- if len(allTo) == 0 {
- return nil, h.ArgErr()
- }
- for _, to := range allTo {
- lb.Upstreams = append(lb.Upstreams, &UpstreamConfig{Host: to})
- }
- }
- return lb, nil
-}
diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go
index 68393de..e312d71 100755..100644
--- a/modules/caddyhttp/reverseproxy/reverseproxy.go
+++ b/modules/caddyhttp/reverseproxy/reverseproxy.go
@@ -15,7 +15,9 @@
package reverseproxy
import (
+ "bytes"
"context"
+ "encoding/json"
"fmt"
"io"
"log"
@@ -24,219 +26,229 @@ import (
"net/url"
"strings"
"sync"
+ "sync/atomic"
"time"
+ "github.com/caddyserver/caddy/v2"
+ "github.com/caddyserver/caddy/v2/modules/caddyhttp"
"golang.org/x/net/http/httpguts"
)
-// ReverseProxy is an HTTP Handler that takes an incoming request and
-// sends it to another server, proxying the response back to the
-// client.
-type ReverseProxy struct {
- // Director must be a function which modifies
- // the request into a new request to be sent
- // using Transport. Its response is then copied
- // back to the original client unmodified.
- // Director must not access the provided Request
- // after returning.
- Director func(*http.Request)
-
- // The transport used to perform proxy requests.
- // If nil, http.DefaultTransport is used.
- Transport http.RoundTripper
-
- // FlushInterval specifies the flush interval
- // to flush to the client while copying the
- // response body.
- // If zero, no periodic flushing is done.
- // A negative value means to flush immediately
- // after each write to the client.
- // The FlushInterval is ignored when ReverseProxy
- // recognizes a response as a streaming response;
- // for such responses, writes are flushed to the client
- // immediately.
- FlushInterval time.Duration
-
- // ErrorLog specifies an optional logger for errors
- // that occur when attempting to proxy the request.
- // If nil, logging goes to os.Stderr via the log package's
- // standard logger.
- ErrorLog *log.Logger
-
- // BufferPool optionally specifies a buffer pool to
- // get byte slices for use by io.CopyBuffer when
- // copying HTTP response bodies.
- BufferPool BufferPool
-
- // ModifyResponse is an optional function that modifies the
- // Response from the backend. It is called if the backend
- // returns a response at all, with any HTTP status code.
- // If the backend is unreachable, the optional ErrorHandler is
- // called without any call to ModifyResponse.
- //
- // If ModifyResponse returns an error, ErrorHandler is called
- // with its error value. If ErrorHandler is nil, its default
- // implementation is used.
- ModifyResponse func(*http.Response) error
-
- // ErrorHandler is an optional function that handles errors
- // reaching the backend or errors from ModifyResponse.
- //
- // If nil, the default is to log the provided error and return
- // a 502 Status Bad Gateway response.
- ErrorHandler func(http.ResponseWriter, *http.Request, error)
-}
-
-// A BufferPool is an interface for getting and returning temporary
-// byte slices for use by io.CopyBuffer.
-type BufferPool interface {
- Get() []byte
- Put([]byte)
+func init() {
+ caddy.RegisterModule(Handler{})
}
-func singleJoiningSlash(a, b string) string {
- aslash := strings.HasSuffix(a, "/")
- bslash := strings.HasPrefix(b, "/")
- switch {
- case aslash && bslash:
- return a + b[1:]
- case !aslash && !bslash:
- return a + "/" + b
+type Handler struct {
+ TransportRaw json.RawMessage `json:"transport,omitempty"`
+ LoadBalancing *LoadBalancing `json:"load_balancing,omitempty"`
+ HealthChecks *HealthChecks `json:"health_checks,omitempty"`
+ // UpstreamStorageRaw json.RawMessage `json:"upstream_storage,omitempty"` // TODO:
+ Upstreams HostPool `json:"upstreams,omitempty"`
+
+ // UpstreamProvider UpstreamProvider `json:"-"` // TODO:
+ Transport http.RoundTripper `json:"-"`
+}
+
+// CaddyModule returns the Caddy module information.
+func (Handler) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy",
+ New: func() caddy.Module { return new(Handler) },
}
- return a + b
}
-// NewSingleHostReverseProxy returns a new ReverseProxy that routes
-// URLs to the scheme, host, and base path provided in target. If the
-// target's path is "/base" and the incoming request was for "/dir",
-// the target request will be for /base/dir.
-// NewSingleHostReverseProxy does not rewrite the Host header.
-// To rewrite Host headers, use ReverseProxy directly with a custom
-// Director policy.
-func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
- targetQuery := target.RawQuery
- director := func(req *http.Request) {
- req.URL.Scheme = target.Scheme
- req.URL.Host = target.Host
- req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
- if targetQuery == "" || req.URL.RawQuery == "" {
- req.URL.RawQuery = targetQuery + req.URL.RawQuery
- } else {
- req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
+func (h *Handler) Provision(ctx caddy.Context) error {
+ if h.TransportRaw != nil {
+ val, err := ctx.LoadModuleInline("protocol", "http.handlers.reverse_proxy.transport", h.TransportRaw)
+ if err != nil {
+ return fmt.Errorf("loading transport module: %s", err)
}
- if _, ok := req.Header["User-Agent"]; !ok {
- // explicitly disable User-Agent so it's not set to default value
- req.Header.Set("User-Agent", "")
+ h.Transport = val.(http.RoundTripper)
+ h.TransportRaw = nil // allow GC to deallocate - TODO: Does this help?
+ }
+ if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil {
+ val, err := ctx.LoadModuleInline("policy",
+ "http.handlers.reverse_proxy.selection_policies",
+ h.LoadBalancing.SelectionPolicyRaw)
+ if err != nil {
+ return fmt.Errorf("loading load balancing selection module: %s", err)
}
+ h.LoadBalancing.SelectionPolicy = val.(Selector)
+ h.LoadBalancing.SelectionPolicyRaw = nil // allow GC to deallocate - TODO: Does this help?
}
- return &ReverseProxy{Director: director}
-}
-func copyHeader(dst, src http.Header) {
- for k, vv := range src {
- for _, v := range vv {
- dst.Add(k, v)
- }
+ if h.Transport == nil {
+ h.Transport = defaultTransport
}
-}
-func cloneHeader(h http.Header) http.Header {
- h2 := make(http.Header, len(h))
- for k, vv := range h {
- vv2 := make([]string, len(vv))
- copy(vv2, vv)
- h2[k] = vv2
+ if h.LoadBalancing == nil {
+ h.LoadBalancing = new(LoadBalancing)
+ }
+ if h.LoadBalancing.SelectionPolicy == nil {
+ h.LoadBalancing.SelectionPolicy = RandomSelection{}
+ }
+ if h.LoadBalancing.TryDuration > 0 && h.LoadBalancing.TryInterval == 0 {
+ // a non-zero try_duration with a zero try_interval
+ // will always spin the CPU for try_duration if the
+ // upstream is local or low-latency; default to some
+ // sane waiting period before try attempts
+ h.LoadBalancing.TryInterval = caddy.Duration(250 * time.Millisecond)
}
- return h2
-}
-// Hop-by-hop headers. These are removed when sent to the backend.
-// As of RFC 7230, hop-by-hop headers are required to appear in the
-// Connection header field. These are the headers defined by the
-// obsoleted RFC 2616 (section 13.5.1) and are used for backward
-// compatibility.
-var hopHeaders = []string{
- "Connection",
- "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
- "Keep-Alive",
- "Proxy-Authenticate",
- "Proxy-Authorization",
- "Te", // canonicalized version of "TE"
- "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
- "Transfer-Encoding",
- "Upgrade",
-}
+ for _, upstream := range h.Upstreams {
+ // url parser requires a scheme
+ if !strings.Contains(upstream.Address, "://") {
+ upstream.Address = "http://" + upstream.Address
+ }
+ u, err := url.Parse(upstream.Address)
+ if err != nil {
+ return fmt.Errorf("invalid upstream address %s: %v", upstream.Address, err)
+ }
+ upstream.hostURL = u
+
+ // if host already exists from a current config,
+ // use that instead; otherwise, add it
+ // TODO: make hosts modular, so that their state can be distributed in enterprise for example
+ // TODO: If distributed, the pool should be stored in storage...
+ var host Host = new(upstreamHost)
+ activeHost, loaded := hosts.LoadOrStore(u.String(), host)
+ if loaded {
+ host = activeHost.(Host)
+ }
+ upstream.Host = host
+
+ // if the passive health checker has a non-zero "unhealthy
+ // request count" but the upstream has no MaxRequests set
+ // (they are the same thing, but one is a default value for
+ // for upstreams with a zero MaxRequests), copy the default
+ // value into this upstream, since the value in the upstream
+ // is what is used during availability checks
+ if h.HealthChecks != nil &&
+ h.HealthChecks.Passive != nil &&
+ h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
+ upstream.MaxRequests == 0 {
+ upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
+ }
-func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
- p.logf("http: proxy error: %v", err)
- rw.WriteHeader(http.StatusBadGateway)
-}
+ // TODO: active health checks
-func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
- if p.ErrorHandler != nil {
- return p.ErrorHandler
+ if h.HealthChecks != nil {
+ // upstreams need independent access to the passive
+ // health check policy so they can, you know, passively
+ // do health checks
+ upstream.healthCheckPolicy = h.HealthChecks.Passive
+ }
}
- return p.defaultErrorHandler
+
+ return nil
}
-// modifyResponse conditionally runs the optional ModifyResponse hook
-// and reports whether the request should proceed.
-func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
- if p.ModifyResponse == nil {
- return true
- }
- if err := p.ModifyResponse(res); err != nil {
- res.Body.Close()
- p.getErrorHandler()(rw, req, err)
- return false
+func (h *Handler) Cleanup() error {
+ // TODO: finish this up, make sure it takes care of any active health checkers or whatever
+ for _, upstream := range h.Upstreams {
+ hosts.Delete(upstream.hostURL.String())
}
- return true
+ return nil
}
-func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*http.Response, error) {
- transport := p.Transport
- if transport == nil {
- transport = http.DefaultTransport
+func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
+ // prepare the request for proxying; this is needed only once
+ err := h.prepareRequest(r)
+ if err != nil {
+ return caddyhttp.Error(http.StatusInternalServerError,
+ fmt.Errorf("preparing request for upstream round-trip: %v", err))
}
- ctx := req.Context()
- if cn, ok := rw.(http.CloseNotifier); ok {
- var cancel context.CancelFunc
- ctx, cancel = context.WithCancel(ctx)
- defer cancel()
- notifyChan := cn.CloseNotify()
- go func() {
- select {
- case <-notifyChan:
- cancel()
- case <-ctx.Done():
+ start := time.Now()
+
+ var proxyErr error
+ for {
+ // choose an available upstream
+ upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, r)
+ if upstream == nil {
+ if proxyErr == nil {
+ proxyErr = fmt.Errorf("no available upstreams")
+ }
+ if !h.tryAgain(start, proxyErr) {
+ break
}
- }()
+ continue
+ }
+
+ // proxy the request to that upstream
+ proxyErr = h.reverseProxy(w, r, upstream)
+ if proxyErr == nil {
+ return nil
+ }
+
+ // remember this failure (if enabled)
+ h.countFailure(upstream)
+
+ // if we've tried long enough, break
+ if !h.tryAgain(start, proxyErr) {
+ break
+ }
}
- outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay
+ return caddyhttp.Error(http.StatusBadGateway, proxyErr)
+}
+
+// prepareRequest modifies req so that it is ready to be proxied,
+// except for directing to a specific upstream. This method mutates
+// headers and other necessary properties of the request and should
+// be done just once (before proxying) regardless of proxy retries.
+// This assumes that no mutations of the request are performed
+// by h during or after proxying.
+func (h Handler) prepareRequest(req *http.Request) error {
+ // ctx := req.Context()
+ // TODO: do we need to support CloseNotifier? It was deprecated years ago.
+ // All this does is wrap CloseNotify with context cancel, for those responsewriters
+ // which didn't support context, but all the ones we'd use should nowadays, right?
+ // if cn, ok := rw.(http.CloseNotifier); ok {
+ // var cancel context.CancelFunc
+ // ctx, cancel = context.WithCancel(ctx)
+ // defer cancel()
+ // notifyChan := cn.CloseNotify()
+ // go func() {
+ // select {
+ // case <-notifyChan:
+ // cancel()
+ // case <-ctx.Done():
+ // }
+ // }()
+ // }
+
+ // TODO: do we need to call WithContext, since we won't be changing req.Context() above if we remove the CloseNotifier stuff?
+ // TODO: (This is where references to req were originally "outreq", a shallow clone, which I think is unnecessary in our case)
+ // req = req.WithContext(ctx) // includes shallow copies of maps, but okay
if req.ContentLength == 0 {
- outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
+ req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries
}
- outreq.Header = cloneHeader(req.Header)
+ // TODO: is this needed?
+ // req.Header = cloneHeader(req.Header)
- p.Director(outreq)
- outreq.Close = false
+ req.Close = false
- reqUpType := upgradeType(outreq.Header)
- removeConnectionHeaders(outreq.Header)
+ // if User-Agent is not set by client, then explicitly
+ // disable it so it's not set to default value by std lib
+ if _, ok := req.Header["User-Agent"]; !ok {
+ req.Header.Set("User-Agent", "")
+ }
+
+ reqUpType := upgradeType(req.Header)
+ removeConnectionHeaders(req.Header)
// Remove hop-by-hop headers to the backend. Especially
// important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us.
for _, h := range hopHeaders {
- hv := outreq.Header.Get(h)
+ hv := req.Header.Get(h)
if hv == "" {
continue
}
if h == "Te" && hv == "trailers" {
- // Issue 21096: tell backend applications that
+ // Issue golang/go#21096: tell backend applications that
// care about trailer support that we support
// trailers. (We do, but we don't go out of
// our way to advertise that unless the
@@ -244,40 +256,65 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht
// worth mentioning)
continue
}
- outreq.Header.Del(h)
+ req.Header.Del(h)
}
// After stripping all the hop-by-hop connection headers above, add back any
// necessary for protocol upgrades, such as for websockets.
if reqUpType != "" {
- outreq.Header.Set("Connection", "Upgrade")
- outreq.Header.Set("Upgrade", reqUpType)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", reqUpType)
}
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
- if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
+ if prior, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
- outreq.Header.Set("X-Forwarded-For", clientIP)
+ req.Header.Set("X-Forwarded-For", clientIP)
}
- res, err := transport.RoundTrip(outreq)
+ return nil
+}
+
+// TODO:
+// this code is the entry point to what was borrowed from the net/http/httputil package in the standard library.
+func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, upstream *Upstream) error {
+ // TODO: count this active request
+
+ // point the request to this upstream
+ h.directRequest(req, upstream)
+
+ // do the round-trip
+ start := time.Now()
+ res, err := h.Transport.RoundTrip(req)
+ latency := time.Since(start)
if err != nil {
- p.getErrorHandler()(rw, outreq, err)
- return nil, err
+ return err
}
- // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
- if res.StatusCode == http.StatusSwitchingProtocols {
- if !p.modifyResponse(rw, res, outreq) {
- return res, nil
+ // perform passive health checks (if enabled)
+ if h.HealthChecks != nil && h.HealthChecks.Passive != nil {
+ // strike if the status code matches one that is "bad"
+ for _, badStatus := range h.HealthChecks.Passive.UnhealthyStatus {
+ if caddyhttp.StatusCodeMatches(res.StatusCode, badStatus) {
+ h.countFailure(upstream)
+ }
}
- p.handleUpgradeResponse(rw, outreq, res)
- return res, nil
+ // strike if the roundtrip took too long
+ if h.HealthChecks.Passive.UnhealthyLatency > 0 &&
+ latency >= time.Duration(h.HealthChecks.Passive.UnhealthyLatency) {
+ h.countFailure(upstream)
+ }
+ }
+
+ // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
+ if res.StatusCode == http.StatusSwitchingProtocols {
+ h.handleUpgradeResponse(rw, req, res)
+ return nil
}
removeConnectionHeaders(res.Header)
@@ -286,10 +323,6 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht
res.Header.Del(h)
}
- if !p.modifyResponse(rw, res, outreq) {
- return res, nil
- }
-
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
@@ -305,15 +338,16 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht
rw.WriteHeader(res.StatusCode)
- err = p.copyResponse(rw, res.Body, p.flushInterval(req, res))
+ err = h.copyResponse(rw, res.Body, h.flushInterval(req, res))
if err != nil {
defer res.Body.Close()
// Since we're streaming the response, if we run into an error all we can do
- // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler
+ // is abort the request. Issue golang/go#23643: ReverseProxy should use ErrAbortHandler
// on read error while copying body.
+ // TODO: Look into whether we want to panic at all in our case...
if !shouldPanicOnCopyError(req) {
- p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
- return nil, err
+ // p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
+ return err
}
panic(http.ErrAbortHandler)
@@ -331,7 +365,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht
if len(res.Trailer) == announcedTrailers {
copyHeader(rw.Header(), res.Trailer)
- return res, nil
+ return nil
}
for k, vv := range res.Trailer {
@@ -341,46 +375,90 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht
}
}
- return res, nil
+ return nil
}
-var inOurTests bool // whether we're in our own tests
-
-// shouldPanicOnCopyError reports whether the reverse proxy should
-// panic with http.ErrAbortHandler. This is the right thing to do by
-// default, but Go 1.10 and earlier did not, so existing unit tests
-// weren't expecting panics. Only panic in our own tests, or when
-// running under the HTTP server.
-func shouldPanicOnCopyError(req *http.Request) bool {
- if inOurTests {
- // Our tests know to handle this panic.
- return true
+// tryAgain takes the time that the handler was initially invoked
+// as well as any error currently obtained and returns true if
+// another attempt should be made at proxying the request. If
+// true is returned, it has already blocked long enough before
+// the next retry (i.e. no more sleeping is needed). If false is
+// returned, the handler should stop trying to proxy the request.
+func (h Handler) tryAgain(start time.Time, proxyErr error) bool {
+ // if downstream has canceled the request, break
+ if proxyErr == context.Canceled {
+ return false
}
- if req.Context().Value(http.ServerContextKey) != nil {
- // We seem to be running under an HTTP server, so
- // it'll recover the panic.
- return true
+ // if we've tried long enough, break
+ if time.Since(start) >= time.Duration(h.LoadBalancing.TryDuration) {
+ return false
}
- // Otherwise act like Go 1.10 and earlier to not break
- // existing tests.
- return false
+ // otherwise, wait and try the next available host
+ time.Sleep(time.Duration(h.LoadBalancing.TryInterval))
+ return true
}
-// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
-// See RFC 7230, section 6.1
-func removeConnectionHeaders(h http.Header) {
- if c := h.Get("Connection"); c != "" {
- for _, f := range strings.Split(c, ",") {
- if f = strings.TrimSpace(f); f != "" {
- h.Del(f)
- }
- }
+// directRequest modifies only req.URL so that it points to the
+// given upstream host. It must modify ONLY the request URL.
+func (h Handler) directRequest(req *http.Request, upstream *Upstream) {
+ target := upstream.hostURL
+ req.URL.Scheme = target.Scheme
+ req.URL.Host = target.Host
+ req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) // TODO: This might be a bug (if any part of the path was augmented from a previously-tried upstream; need to start from clean original path of request, same for query string!)
+ if target.RawQuery == "" || req.URL.RawQuery == "" {
+ req.URL.RawQuery = target.RawQuery + req.URL.RawQuery
+ } else {
+ req.URL.RawQuery = target.RawQuery + "&" + req.URL.RawQuery
}
}
+func (h Handler) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
+ reqUpType := upgradeType(req.Header)
+ resUpType := upgradeType(res.Header)
+ if reqUpType != resUpType {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
+ return
+ }
+
+ copyHeader(res.Header, rw.Header())
+
+ hj, ok := rw.(http.Hijacker)
+ if !ok {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
+ return
+ }
+ backConn, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
+ return
+ }
+ defer backConn.Close()
+ conn, brw, err := hj.Hijack()
+ if err != nil {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
+ return
+ }
+ defer conn.Close()
+ res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
+ if err := res.Write(brw); err != nil {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
+ return
+ }
+ if err := brw.Flush(); err != nil {
+ // p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
+ return
+ }
+ errc := make(chan error, 1)
+ spc := switchProtocolCopier{user: conn, backend: backConn}
+ go spc.copyToBackend(errc)
+ go spc.copyFromBackend(errc)
+ <-errc
+ return
+}
+
// flushInterval returns the p.FlushInterval value, conditionally
// overriding its value for a specific request/response.
-func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration {
+func (h Handler) flushInterval(req *http.Request, res *http.Response) time.Duration {
resCT := res.Header.Get("Content-Type")
// For Server-Sent Events responses, flush immediately.
@@ -390,10 +468,11 @@ func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time
}
// TODO: more specific cases? e.g. res.ContentLength == -1?
- return p.FlushInterval
+ // return h.FlushInterval
+ return 0
}
-func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
+func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
if flushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
@@ -405,18 +484,25 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval
}
}
+ // TODO: Figure out how we want to do this... using custom buffer pool type seems unnecessary
+ // or maybe it is, depending on how we want to handle errors,
+ // see: https://github.com/golang/go/issues/21814
+ // buf := bufPool.Get().(*bytes.Buffer)
+ // buf.Reset()
+ // defer bufPool.Put(buf)
+ // _, err := io.CopyBuffer(dst, src, )
var buf []byte
- if p.BufferPool != nil {
- buf = p.BufferPool.Get()
- defer p.BufferPool.Put(buf)
- }
- _, err := p.copyBuffer(dst, src, buf)
+ // if h.BufferPool != nil {
+ // buf = h.BufferPool.Get()
+ // defer h.BufferPool.Put(buf)
+ // }
+ _, err := h.copyBuffer(dst, src, buf)
return err
}
// copyBuffer returns any write errors or non-EOF read errors, and the amount
// of bytes written.
-func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
+func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
if len(buf) == 0 {
buf = make([]byte, 32*1024)
}
@@ -424,7 +510,13 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int
for {
nr, rerr := src.Read(buf)
if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
- p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
+ // TODO: this could be useful to know (indeed, it revealed an error in our
+ // fastcgi PoC earlier; but it's this single error report here that necessitates
+ // a function separate from io.CopyBuffer, since io.CopyBuffer does not distinguish
+ // between read or write errors; in a reverse proxy situation, write errors are not
+ // something we need to report to the client, but read errors are a problem on our
+ // end for sure. so we need to decide what we want.)
+ // p.logf("copyBuffer: ReverseProxy read error during body copy: %v", rerr)
}
if nr > 0 {
nw, werr := dst.Write(buf[:nr])
@@ -447,12 +539,36 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int
}
}
-func (p *ReverseProxy) logf(format string, args ...interface{}) {
- if p.ErrorLog != nil {
- p.ErrorLog.Printf(format, args...)
- } else {
- log.Printf(format, args...)
+// countFailure remembers 1 failure for upstream for the
+// configured duration. If passive health checks are
+// disabled or failure expiry is 0, this is a no-op.
+func (h Handler) countFailure(upstream *Upstream) {
+ // only count failures if passive health checking is enabled
+ // and if failures are configured have a non-zero expiry
+ if h.HealthChecks == nil || h.HealthChecks.Passive == nil {
+ return
+ }
+ failDuration := time.Duration(h.HealthChecks.Passive.FailDuration)
+ if failDuration == 0 {
+ return
+ }
+
+ // count failure immediately
+ err := upstream.Host.CountFail(1)
+ if err != nil {
+ log.Printf("[ERROR] proxy: upstream %s: counting failure: %v",
+ upstream.hostURL, err)
}
+
+ // forget it later
+ go func(host Host, failDuration time.Duration) {
+ time.Sleep(failDuration)
+ err := host.CountFail(-1)
+ if err != nil {
+ log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v",
+ upstream.hostURL, err)
+ }
+ }(upstream.Host, failDuration)
}
type writeFlusher interface {
@@ -508,6 +624,61 @@ func (m *maxLatencyWriter) stop() {
}
}
+// switchProtocolCopier exists so goroutines proxying data back and
+// forth have nice names in stacks.
+type switchProtocolCopier struct {
+ user, backend io.ReadWriter
+}
+
+func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
+ _, err := io.Copy(c.user, c.backend)
+ errc <- err
+}
+
+func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
+ _, err := io.Copy(c.backend, c.user)
+ errc <- err
+}
+
+// shouldPanicOnCopyError reports whether the reverse proxy should
+// panic with http.ErrAbortHandler. This is the right thing to do by
+// default, but Go 1.10 and earlier did not, so existing unit tests
+// weren't expecting panics. Only panic in our own tests, or when
+// running under the HTTP server.
+// TODO: I don't know if we want this at all...
+func shouldPanicOnCopyError(req *http.Request) bool {
+ // if inOurTests {
+ // // Our tests know to handle this panic.
+ // return true
+ // }
+ if req.Context().Value(http.ServerContextKey) != nil {
+ // We seem to be running under an HTTP server, so
+ // it'll recover the panic.
+ return true
+ }
+ // Otherwise act like Go 1.10 and earlier to not break
+ // existing tests.
+ return false
+}
+
+func copyHeader(dst, src http.Header) {
+ for k, vv := range src {
+ for _, v := range vv {
+ dst.Add(k, v)
+ }
+ }
+}
+
+func cloneHeader(h http.Header) http.Header {
+ h2 := make(http.Header, len(h))
+ for k, vv := range h {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ h2[k] = vv2
+ }
+ return h2
+}
+
func upgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
return ""
@@ -515,62 +686,177 @@ func upgradeType(h http.Header) string {
return strings.ToLower(h.Get("Upgrade"))
}
-func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
- reqUpType := upgradeType(req.Header)
- resUpType := upgradeType(res.Header)
- if reqUpType != resUpType {
- p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
- return
+func singleJoiningSlash(a, b string) string {
+ aslash := strings.HasSuffix(a, "/")
+ bslash := strings.HasPrefix(b, "/")
+ switch {
+ case aslash && bslash:
+ return a + b[1:]
+ case !aslash && !bslash:
+ return a + "/" + b
}
+ return a + b
+}
- copyHeader(res.Header, rw.Header())
-
- hj, ok := rw.(http.Hijacker)
- if !ok {
- p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
- return
- }
- backConn, ok := res.Body.(io.ReadWriteCloser)
- if !ok {
- p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
- return
+// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
+// See RFC 7230, section 6.1
+func removeConnectionHeaders(h http.Header) {
+ if c := h.Get("Connection"); c != "" {
+ for _, f := range strings.Split(c, ",") {
+ if f = strings.TrimSpace(f); f != "" {
+ h.Del(f)
+ }
+ }
}
- defer backConn.Close()
- conn, brw, err := hj.Hijack()
- if err != nil {
- p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
- return
+}
+
+type LoadBalancing struct {
+ SelectionPolicyRaw json.RawMessage `json:"selection_policy,omitempty"`
+ TryDuration caddy.Duration `json:"try_duration,omitempty"`
+ TryInterval caddy.Duration `json:"try_interval,omitempty"`
+
+ SelectionPolicy Selector `json:"-"`
+}
+
+type Selector interface {
+ Select(HostPool, *http.Request) *Upstream
+}
+
+type HealthChecks struct {
+ Active *ActiveHealthChecks `json:"active,omitempty"`
+ Passive *PassiveHealthChecks `json:"passive,omitempty"`
+}
+
+type ActiveHealthChecks struct {
+ Path string `json:"path,omitempty"`
+ Port int `json:"port,omitempty"`
+ Interval caddy.Duration `json:"interval,omitempty"`
+ Timeout caddy.Duration `json:"timeout,omitempty"`
+ MaxSize int `json:"max_size,omitempty"`
+ ExpectStatus int `json:"expect_status,omitempty"`
+ ExpectBody string `json:"expect_body,omitempty"`
+}
+
+type PassiveHealthChecks struct {
+ MaxFails int `json:"max_fails,omitempty"`
+ FailDuration caddy.Duration `json:"fail_duration,omitempty"`
+ UnhealthyRequestCount int `json:"unhealthy_request_count,omitempty"`
+ UnhealthyStatus []int `json:"unhealthy_status,omitempty"`
+ UnhealthyLatency caddy.Duration `json:"unhealthy_latency,omitempty"`
+}
+
+// Hop-by-hop headers. These are removed when sent to the backend.
+// As of RFC 7230, hop-by-hop headers are required to appear in the
+// Connection header field. These are the headers defined by the
+// obsoleted RFC 2616 (section 13.5.1) and are used for backward
+// compatibility.
+var hopHeaders = []string{
+ "Connection",
+ "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
+ "Keep-Alive",
+ "Proxy-Authenticate",
+ "Proxy-Authorization",
+ "Te", // canonicalized version of "TE"
+ "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
+ "Transfer-Encoding",
+ "Upgrade",
+}
+
+var bufPool = sync.Pool{
+ New: func() interface{} {
+ return new(bytes.Buffer)
+ },
+}
+
+//////////////////////////////////
+// TODO:
+
+type Host interface {
+ NumRequests() int
+ Fails() int
+ Unhealthy() bool
+
+ CountRequest(int) error
+ CountFail(int) error
+}
+
+type HostPool []*Upstream
+
+type upstreamHost struct {
+ numRequests int64 // must be first field to be 64-bit aligned on 32-bit systems (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
+ fails int64
+ unhealthy int32
+}
+
+func (uh upstreamHost) NumRequests() int {
+ return int(atomic.LoadInt64(&uh.numRequests))
+}
+func (uh upstreamHost) Fails() int {
+ return int(atomic.LoadInt64(&uh.fails))
+}
+func (uh upstreamHost) Unhealthy() bool {
+ return atomic.LoadInt32(&uh.unhealthy) == 1
+}
+func (uh *upstreamHost) CountRequest(delta int) error {
+ result := atomic.AddInt64(&uh.numRequests, int64(delta))
+ if result < 0 {
+ return fmt.Errorf("count below 0: %d", result)
}
- defer conn.Close()
- res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
- if err := res.Write(brw); err != nil {
- p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
- return
+ return nil
+}
+func (uh *upstreamHost) CountFail(delta int) error {
+ result := atomic.AddInt64(&uh.fails, int64(delta))
+ if result < 0 {
+ return fmt.Errorf("count below 0: %d", result)
}
- if err := brw.Flush(); err != nil {
- p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
- return
+ return nil
+}
+
+type Upstream struct {
+ Host `json:"-"`
+
+ Address string `json:"address,omitempty"`
+ MaxRequests int `json:"max_requests,omitempty"`
+
+ // TODO: This could be really cool, to say that requests with
+ // certain headers or from certain IPs always go to this upstream
+ // HeaderAffinity string
+ // IPAffinity string
+
+ healthCheckPolicy *PassiveHealthChecks
+
+ hostURL *url.URL
+}
+
+func (u Upstream) Available() bool {
+ return u.Healthy() && !u.Full()
+}
+
+func (u Upstream) Healthy() bool {
+ healthy := !u.Host.Unhealthy()
+ if healthy && u.healthCheckPolicy != nil {
+ healthy = u.Host.Fails() < u.healthCheckPolicy.MaxFails
}
- errc := make(chan error, 1)
- spc := switchProtocolCopier{user: conn, backend: backConn}
- go spc.copyToBackend(errc)
- go spc.copyFromBackend(errc)
- <-errc
- return
+ return healthy
}
-// switchProtocolCopier exists so goroutines proxying data back and
-// forth have nice names in stacks.
-type switchProtocolCopier struct {
- user, backend io.ReadWriter
+func (u Upstream) Full() bool {
+ return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests
}
-func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
- _, err := io.Copy(c.user, c.backend)
- errc <- err
+func (u Upstream) URL() *url.URL {
+ return u.hostURL
}
-func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
- _, err := io.Copy(c.backend, c.user)
- errc <- err
+var hosts = caddy.NewUsagePool()
+
+// TODO: ...
+type UpstreamProvider interface {
}
+
+// Interface guards
+var (
+ _ caddyhttp.MiddlewareHandler = (*Handler)(nil)
+ _ caddy.Provisioner = (*Handler)(nil)
+ _ caddy.CleanerUpper = (*Handler)(nil)
+)
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go
new file mode 100644
index 0000000..e0518c9
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go
@@ -0,0 +1,351 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reverseproxy
+
+import (
+ "fmt"
+ "hash/fnv"
+ weakrand "math/rand"
+ "net"
+ "net/http"
+ "sync/atomic"
+ "time"
+
+ "github.com/caddyserver/caddy/v2"
+)
+
+func init() {
+ caddy.RegisterModule(RandomSelection{})
+ caddy.RegisterModule(RandomChoiceSelection{})
+ caddy.RegisterModule(LeastConnSelection{})
+ caddy.RegisterModule(RoundRobinSelection{})
+ caddy.RegisterModule(FirstSelection{})
+ caddy.RegisterModule(IPHashSelection{})
+ caddy.RegisterModule(URIHashSelection{})
+ caddy.RegisterModule(HeaderHashSelection{})
+
+ weakrand.Seed(time.Now().UTC().UnixNano())
+}
+
+// RandomSelection is a policy that selects
+// an available host at random.
+type RandomSelection struct{}
+
+// CaddyModule returns the Caddy module information.
+func (RandomSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.random",
+ New: func() caddy.Module { return new(RandomSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (r RandomSelection) Select(pool HostPool, request *http.Request) *Upstream {
+ // use reservoir sampling because the number of available
+ // hosts isn't known: https://en.wikipedia.org/wiki/Reservoir_sampling
+ var randomHost *Upstream
+ var count int
+ for _, upstream := range pool {
+ if !upstream.Available() {
+ continue
+ }
+ // (n % 1 == 0) holds for all n, therefore a
+ // upstream will always be chosen if there is at
+ // least one available
+ count++
+ if (weakrand.Int() % count) == 0 {
+ randomHost = upstream
+ }
+ }
+ return randomHost
+}
+
+// RandomChoiceSelection is a policy that selects
+// two or more available hosts at random, then
+// chooses the one with the least load.
+type RandomChoiceSelection struct {
+ Choose int `json:"choose,omitempty"`
+}
+
+// CaddyModule returns the Caddy module information.
+func (RandomChoiceSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.random_choice",
+ New: func() caddy.Module { return new(RandomChoiceSelection) },
+ }
+}
+
+func (r *RandomChoiceSelection) Provision(ctx caddy.Context) error {
+ if r.Choose == 0 {
+ r.Choose = 2
+ }
+ return nil
+}
+
+func (r RandomChoiceSelection) Validate() error {
+ if r.Choose < 2 {
+ return fmt.Errorf("choose must be at least 2")
+ }
+ return nil
+}
+
+// Select returns an available host, if any.
+func (r RandomChoiceSelection) Select(pool HostPool, _ *http.Request) *Upstream {
+ k := r.Choose
+ if k > len(pool) {
+ k = len(pool)
+ }
+ choices := make([]*Upstream, k)
+ for i, upstream := range pool {
+ if !upstream.Available() {
+ continue
+ }
+ j := weakrand.Intn(i)
+ if j < k {
+ choices[j] = upstream
+ }
+ }
+ return leastRequests(choices)
+}
+
+// LeastConnSelection is a policy that selects the
+// host with the least active requests. If multiple
+// hosts have the same fewest number, one is chosen
+// randomly. The term "conn" or "connection" is used
+// in this policy name due to its similar meaning in
+// other software, but our load balancer actually
+// counts active requests rather than connections,
+// since these days requests are multiplexed onto
+// shared connections.
+type LeastConnSelection struct{}
+
+// CaddyModule returns the Caddy module information.
+func (LeastConnSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.least_conn",
+ New: func() caddy.Module { return new(LeastConnSelection) },
+ }
+}
+
+// Select selects the up host with the least number of connections in the
+// pool. If more than one host has the same least number of connections,
+// one of the hosts is chosen at random.
+func (LeastConnSelection) Select(pool HostPool, _ *http.Request) *Upstream {
+ var bestHost *Upstream
+ var count int
+ var leastReqs int
+
+ for _, host := range pool {
+ if !host.Available() {
+ continue
+ }
+ numReqs := host.NumRequests()
+ if numReqs < leastReqs {
+ leastReqs = numReqs
+ count = 0
+ }
+
+ // among hosts with same least connections, perform a reservoir
+ // sample: https://en.wikipedia.org/wiki/Reservoir_sampling
+ if numReqs == leastReqs {
+ count++
+ if (weakrand.Int() % count) == 0 {
+ bestHost = host
+ }
+ }
+ }
+
+ return bestHost
+}
+
+// RoundRobinSelection is a policy that selects
+// a host based on round-robin ordering.
+type RoundRobinSelection struct {
+ robin uint32
+}
+
+// CaddyModule returns the Caddy module information.
+func (RoundRobinSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.round_robin",
+ New: func() caddy.Module { return new(RoundRobinSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (r *RoundRobinSelection) Select(pool HostPool, _ *http.Request) *Upstream {
+ n := uint32(len(pool))
+ if n == 0 {
+ return nil
+ }
+ for i := uint32(0); i < n; i++ {
+ atomic.AddUint32(&r.robin, 1)
+ host := pool[r.robin%n]
+ if host.Available() {
+ return host
+ }
+ }
+ return nil
+}
+
+// FirstSelection is a policy that selects
+// the first available host.
+type FirstSelection struct{}
+
+// CaddyModule returns the Caddy module information.
+func (FirstSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.first",
+ New: func() caddy.Module { return new(FirstSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (FirstSelection) Select(pool HostPool, _ *http.Request) *Upstream {
+ for _, host := range pool {
+ if host.Available() {
+ return host
+ }
+ }
+ return nil
+}
+
+// IPHashSelection is a policy that selects a host
+// based on hashing the remote IP of the request.
+type IPHashSelection struct{}
+
+// CaddyModule returns the Caddy module information.
+func (IPHashSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.ip_hash",
+ New: func() caddy.Module { return new(IPHashSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (IPHashSelection) Select(pool HostPool, req *http.Request) *Upstream {
+ clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
+ if err != nil {
+ clientIP = req.RemoteAddr
+ }
+ return hostByHashing(pool, clientIP)
+}
+
+// URIHashSelection is a policy that selects a
+// host by hashing the request URI.
+type URIHashSelection struct{}
+
+// CaddyModule returns the Caddy module information.
+func (URIHashSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.uri_hash",
+ New: func() caddy.Module { return new(URIHashSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (URIHashSelection) Select(pool HostPool, req *http.Request) *Upstream {
+ return hostByHashing(pool, req.RequestURI)
+}
+
+// HeaderHashSelection is a policy that selects
+// a host based on a given request header.
+type HeaderHashSelection struct {
+ Field string `json:"field,omitempty"`
+}
+
+// CaddyModule returns the Caddy module information.
+func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo {
+ return caddy.ModuleInfo{
+ Name: "http.handlers.reverse_proxy.selection_policies.header",
+ New: func() caddy.Module { return new(HeaderHashSelection) },
+ }
+}
+
+// Select returns an available host, if any.
+func (s HeaderHashSelection) Select(pool HostPool, req *http.Request) *Upstream {
+ if s.Field == "" {
+ return nil
+ }
+ val := req.Header.Get(s.Field)
+ if val == "" {
+ return RandomSelection{}.Select(pool, req)
+ }
+ return hostByHashing(pool, val)
+}
+
+// leastRequests returns the host with the
+// least number of active requests to it.
+// If more than one host has the same
+// least number of active requests, then
+// one of those is chosen at random.
+func leastRequests(upstreams []*Upstream) *Upstream {
+ if len(upstreams) == 0 {
+ return nil
+ }
+ var best []*Upstream
+ var bestReqs int
+ for _, upstream := range upstreams {
+ reqs := upstream.NumRequests()
+ if reqs == 0 {
+ return upstream
+ }
+ if reqs <= bestReqs {
+ bestReqs = reqs
+ best = append(best, upstream)
+ }
+ }
+ return best[weakrand.Intn(len(best))]
+}
+
+// hostByHashing returns an available host
+// from pool based on a hashable string s.
+func hostByHashing(pool []*Upstream, s string) *Upstream {
+ poolLen := uint32(len(pool))
+ if poolLen == 0 {
+ return nil
+ }
+ index := hash(s) % poolLen
+ for i := uint32(0); i < poolLen; i++ {
+ index += i
+ upstream := pool[index%poolLen]
+ if upstream.Available() {
+ return upstream
+ }
+ }
+ return nil
+}
+
+// hash calculates a fast hash based on s.
+func hash(s string) uint32 {
+ h := fnv.New32a()
+ h.Write([]byte(s))
+ return h.Sum32()
+}
+
+// Interface guards
+var (
+ _ Selector = (*RandomSelection)(nil)
+ _ Selector = (*RandomChoiceSelection)(nil)
+ _ Selector = (*LeastConnSelection)(nil)
+ _ Selector = (*RoundRobinSelection)(nil)
+ _ Selector = (*FirstSelection)(nil)
+ _ Selector = (*IPHashSelection)(nil)
+ _ Selector = (*URIHashSelection)(nil)
+ _ Selector = (*HeaderHashSelection)(nil)
+
+ _ caddy.Validator = (*RandomChoiceSelection)(nil)
+ _ caddy.Provisioner = (*RandomChoiceSelection)(nil)
+)
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
new file mode 100644
index 0000000..8006fb1
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
@@ -0,0 +1,363 @@
+// Copyright 2015 Matthew Holt and The Caddy Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reverseproxy
+
+// TODO: finish migrating these
+
+// import (
+// "net/http"
+// "net/http/httptest"
+// "os"
+// "testing"
+// )
+
+// var workableServer *httptest.Server
+
+// func TestMain(m *testing.M) {
+// workableServer = httptest.NewServer(http.HandlerFunc(
+// func(w http.ResponseWriter, r *http.Request) {
+// // do nothing
+// }))
+// r := m.Run()
+// workableServer.Close()
+// os.Exit(r)
+// }
+
+// type customPolicy struct{}
+
+// func (customPolicy) Select(pool HostPool, _ *http.Request) Host {
+// return pool[0]
+// }
+
+// func testPool() HostPool {
+// pool := []*UpstreamHost{
+// {
+// Name: workableServer.URL, // this should resolve (healthcheck test)
+// },
+// {
+// Name: "http://localhost:99998", // this shouldn't
+// },
+// {
+// Name: "http://C",
+// },
+// }
+// return HostPool(pool)
+// }
+
+// func TestRoundRobinPolicy(t *testing.T) {
+// pool := testPool()
+// rrPolicy := &RoundRobin{}
+// request, _ := http.NewRequest("GET", "/", nil)
+
+// h := rrPolicy.Select(pool, request)
+// // First selected host is 1, because counter starts at 0
+// // and increments before host is selected
+// if h != pool[1] {
+// t.Error("Expected first round robin host to be second host in the pool.")
+// }
+// h = rrPolicy.Select(pool, request)
+// if h != pool[2] {
+// t.Error("Expected second round robin host to be third host in the pool.")
+// }
+// h = rrPolicy.Select(pool, request)
+// if h != pool[0] {
+// t.Error("Expected third round robin host to be first host in the pool.")
+// }
+// // mark host as down
+// pool[1].Unhealthy = 1
+// h = rrPolicy.Select(pool, request)
+// if h != pool[2] {
+// t.Error("Expected to skip down host.")
+// }
+// // mark host as up
+// pool[1].Unhealthy = 0
+
+// h = rrPolicy.Select(pool, request)
+// if h == pool[2] {
+// t.Error("Expected to balance evenly among healthy hosts")
+// }
+// // mark host as full
+// pool[1].Conns = 1
+// pool[1].MaxConns = 1
+// h = rrPolicy.Select(pool, request)
+// if h != pool[2] {
+// t.Error("Expected to skip full host.")
+// }
+// }
+
+// func TestLeastConnPolicy(t *testing.T) {
+// pool := testPool()
+// lcPolicy := &LeastConn{}
+// request, _ := http.NewRequest("GET", "/", nil)
+
+// pool[0].Conns = 10
+// pool[1].Conns = 10
+// h := lcPolicy.Select(pool, request)
+// if h != pool[2] {
+// t.Error("Expected least connection host to be third host.")
+// }
+// pool[2].Conns = 100
+// h = lcPolicy.Select(pool, request)
+// if h != pool[0] && h != pool[1] {
+// t.Error("Expected least connection host to be first or second host.")
+// }
+// }
+
+// func TestCustomPolicy(t *testing.T) {
+// pool := testPool()
+// customPolicy := &customPolicy{}
+// request, _ := http.NewRequest("GET", "/", nil)
+
+// h := customPolicy.Select(pool, request)
+// if h != pool[0] {
+// t.Error("Expected custom policy host to be the first host.")
+// }
+// }
+
+// func TestIPHashPolicy(t *testing.T) {
+// pool := testPool()
+// ipHash := &IPHash{}
+// request, _ := http.NewRequest("GET", "/", nil)
+// // We should be able to predict where every request is routed.
+// request.RemoteAddr = "172.0.0.1:80"
+// h := ipHash.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected ip hash policy host to be the second host.")
+// }
+// request.RemoteAddr = "172.0.0.2:80"
+// h = ipHash.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected ip hash policy host to be the second host.")
+// }
+// request.RemoteAddr = "172.0.0.3:80"
+// h = ipHash.Select(pool, request)
+// if h != pool[2] {
+// t.Error("Expected ip hash policy host to be the third host.")
+// }
+// request.RemoteAddr = "172.0.0.4:80"
+// h = ipHash.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected ip hash policy host to be the second host.")
+// }
+
+// // we should get the same results without a port
+// request.RemoteAddr = "172.0.0.1"
+// h = ipHash.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected ip hash policy host to be the second host.")
+// }
+// request.RemoteAddr = "172.0.0.2"
+// h = ipHash.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected ip hash policy host to be the second host.")
+// }
+// request.RemoteAddr = "172.0.0.3"
+// h = ipHash.Select(pool, request)
+// if h != pool[2] {
+// t.Error("Expected ip hash policy host to be the third host.")
+// }
+// request.RemoteAddr = "172.0.0.4"
+// h = ipHash.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected ip hash policy host to be the second host.")
+// }
+
+// // we should get a healthy host if the original host is unhealthy and a
+// // healthy host is available
+// request.RemoteAddr = "172.0.0.1"
+// pool[1].Unhealthy = 1
+// h = ipHash.Select(pool, request)
+// if h != pool[2] {
+// t.Error("Expected ip hash policy host to be the third host.")
+// }
+
+// request.RemoteAddr = "172.0.0.2"
+// h = ipHash.Select(pool, request)
+// if h != pool[2] {
+// t.Error("Expected ip hash policy host to be the third host.")
+// }
+// pool[1].Unhealthy = 0
+
+// request.RemoteAddr = "172.0.0.3"
+// pool[2].Unhealthy = 1
+// h = ipHash.Select(pool, request)
+// if h != pool[0] {
+// t.Error("Expected ip hash policy host to be the first host.")
+// }
+// request.RemoteAddr = "172.0.0.4"
+// h = ipHash.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected ip hash policy host to be the second host.")
+// }
+
+// // We should be able to resize the host pool and still be able to predict
+// // where a request will be routed with the same IP's used above
+// pool = []*UpstreamHost{
+// {
+// Name: workableServer.URL, // this should resolve (healthcheck test)
+// },
+// {
+// Name: "http://localhost:99998", // this shouldn't
+// },
+// }
+// pool = HostPool(pool)
+// request.RemoteAddr = "172.0.0.1:80"
+// h = ipHash.Select(pool, request)
+// if h != pool[0] {
+// t.Error("Expected ip hash policy host to be the first host.")
+// }
+// request.RemoteAddr = "172.0.0.2:80"
+// h = ipHash.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected ip hash policy host to be the second host.")
+// }
+// request.RemoteAddr = "172.0.0.3:80"
+// h = ipHash.Select(pool, request)
+// if h != pool[0] {
+// t.Error("Expected ip hash policy host to be the first host.")
+// }
+// request.RemoteAddr = "172.0.0.4:80"
+// h = ipHash.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected ip hash policy host to be the second host.")
+// }
+
+// // We should get nil when there are no healthy hosts
+// pool[0].Unhealthy = 1
+// pool[1].Unhealthy = 1
+// h = ipHash.Select(pool, request)
+// if h != nil {
+// t.Error("Expected ip hash policy host to be nil.")
+// }
+// }
+
+// func TestFirstPolicy(t *testing.T) {
+// pool := testPool()
+// firstPolicy := &First{}
+// req := httptest.NewRequest(http.MethodGet, "/", nil)
+
+// h := firstPolicy.Select(pool, req)
+// if h != pool[0] {
+// t.Error("Expected first policy host to be the first host.")
+// }
+
+// pool[0].Unhealthy = 1
+// h = firstPolicy.Select(pool, req)
+// if h != pool[1] {
+// t.Error("Expected first policy host to be the second host.")
+// }
+// }
+
+// func TestUriPolicy(t *testing.T) {
+// pool := testPool()
+// uriPolicy := &URIHash{}
+
+// request := httptest.NewRequest(http.MethodGet, "/test", nil)
+// h := uriPolicy.Select(pool, request)
+// if h != pool[0] {
+// t.Error("Expected uri policy host to be the first host.")
+// }
+
+// pool[0].Unhealthy = 1
+// h = uriPolicy.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected uri policy host to be the first host.")
+// }
+
+// request = httptest.NewRequest(http.MethodGet, "/test_2", nil)
+// h = uriPolicy.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected uri policy host to be the second host.")
+// }
+
+// // We should be able to resize the host pool and still be able to predict
+// // where a request will be routed with the same URI's used above
+// pool = []*UpstreamHost{
+// {
+// Name: workableServer.URL, // this should resolve (healthcheck test)
+// },
+// {
+// Name: "http://localhost:99998", // this shouldn't
+// },
+// }
+
+// request = httptest.NewRequest(http.MethodGet, "/test", nil)
+// h = uriPolicy.Select(pool, request)
+// if h != pool[0] {
+// t.Error("Expected uri policy host to be the first host.")
+// }
+
+// pool[0].Unhealthy = 1
+// h = uriPolicy.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected uri policy host to be the first host.")
+// }
+
+// request = httptest.NewRequest(http.MethodGet, "/test_2", nil)
+// h = uriPolicy.Select(pool, request)
+// if h != pool[1] {
+// t.Error("Expected uri policy host to be the second host.")
+// }
+
+// pool[0].Unhealthy = 1
+// pool[1].Unhealthy = 1
+// h = uriPolicy.Select(pool, request)
+// if h != nil {
+// t.Error("Expected uri policy policy host to be nil.")
+// }
+// }
+
+// func TestHeaderPolicy(t *testing.T) {
+// pool := testPool()
+// tests := []struct {
+// Name string
+// Policy *Header
+// RequestHeaderName string
+// RequestHeaderValue string
+// NilHost bool
+// HostIndex int
+// }{
+// {"empty config", &Header{""}, "", "", true, 0},
+// {"empty config+header+value", &Header{""}, "Affinity", "somevalue", true, 0},
+// {"empty config+header", &Header{""}, "Affinity", "", true, 0},
+
+// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 1},
+// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 2},
+// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 0},
+
+// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue", false, 1},
+// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue2", false, 0},
+// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue3", false, 2},
+// {"hash route with empty value", &Header{"Affinity"}, "Affinity", "", false, 1},
+// }
+
+// for idx, test := range tests {
+// request, _ := http.NewRequest("GET", "/", nil)
+// if test.RequestHeaderName != "" {
+// request.Header.Add(test.RequestHeaderName, test.RequestHeaderValue)
+// }
+
+// host := test.Policy.Select(pool, request)
+// if test.NilHost && host != nil {
+// t.Errorf("%d: Expected host to be nil", idx)
+// }
+// if !test.NilHost && host == nil {
+// t.Errorf("%d: Did not expect host to be nil", idx)
+// }
+// if !test.NilHost && host != pool[test.HostIndex] {
+// t.Errorf("%d: Expected Header policy to be host %d", idx, test.HostIndex)
+// }
+// }
+// }
diff --git a/modules/caddyhttp/reverseproxy/upstream.go b/modules/caddyhttp/reverseproxy/upstream.go
deleted file mode 100755
index 1f0693e..0000000
--- a/modules/caddyhttp/reverseproxy/upstream.go
+++ /dev/null
@@ -1,450 +0,0 @@
-// Copyright 2015 Matthew Holt and The Caddy Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package reverseproxy implements a load-balanced reverse proxy.
-package reverseproxy
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "math/rand"
- "net"
- "net/http"
- "net/url"
- "strings"
- "sync"
- "sync/atomic"
- "time"
-
- "github.com/caddyserver/caddy/v2"
- "github.com/caddyserver/caddy/v2/modules/caddyhttp"
-)
-
-// CircuitBreaker defines the functionality of a circuit breaker module.
-type CircuitBreaker interface {
- Ok() bool
- RecordMetric(statusCode int, latency time.Duration)
-}
-
-type noopCircuitBreaker struct{}
-
-func (ncb noopCircuitBreaker) RecordMetric(statusCode int, latency time.Duration) {}
-func (ncb noopCircuitBreaker) Ok() bool {
- return true
-}
-
-const (
- // TypeBalanceRoundRobin represents the value to use for configuring a load balanced reverse proxy to use round robin load balancing.
- TypeBalanceRoundRobin = iota
-
- // TypeBalanceRandom represents the value to use for configuring a load balanced reverse proxy to use random load balancing.
- TypeBalanceRandom
-
- // TODO: add random with two choices
-
- // msgNoHealthyUpstreams is returned if there are no upstreams that are healthy to proxy a request to
- msgNoHealthyUpstreams = "No healthy upstreams."
-
- // by default perform health checks every 30 seconds
- defaultHealthCheckDur = time.Second * 30
-
- // used when an upstream is unhealthy, health checks can be configured to perform at a faster rate
- defaultFastHealthCheckDur = time.Second * 1
-)
-
-var (
- // defaultTransport is the default transport to use for the reverse proxy.
- defaultTransport = &http.Transport{
- Dial: (&net.Dialer{
- Timeout: 5 * time.Second,
- }).Dial,
- TLSHandshakeTimeout: 5 * time.Second,
- }
-
- // defaultHTTPClient is the default http client to use for the healthchecker.
- defaultHTTPClient = &http.Client{
- Timeout: time.Second * 10,
- Transport: defaultTransport,
- }
-
- // typeMap maps caddy load balance configuration to the internal representation of the loadbalance algorithm type.
- typeMap = map[string]int{
- "round_robin": TypeBalanceRoundRobin,
- "random": TypeBalanceRandom,
- }
-)
-
-// NewLoadBalancedReverseProxy returns a collection of Upstreams that are to be loadbalanced.
-func NewLoadBalancedReverseProxy(lb *LoadBalanced, ctx caddy.Context) error {
- // set defaults
- if lb.NoHealthyUpstreamsMessage == "" {
- lb.NoHealthyUpstreamsMessage = msgNoHealthyUpstreams
- }
-
- if lb.TryInterval == "" {
- lb.TryInterval = "20s"
- }
-
- // set request retry interval
- ti, err := time.ParseDuration(lb.TryInterval)
- if err != nil {
- return fmt.Errorf("NewLoadBalancedReverseProxy: %v", err.Error())
- }
- lb.tryInterval = ti
-
- // set load balance algorithm
- t, ok := typeMap[lb.LoadBalanceType]
- if !ok {
- t = TypeBalanceRandom
- }
- lb.loadBalanceType = t
-
- // setup each upstream
- var us []*upstream
- for _, uc := range lb.Upstreams {
- // pass the upstream decr and incr methods to keep track of unhealthy nodes
- nu, err := newUpstream(uc, lb.decrUnhealthy, lb.incrUnhealthy)
- if err != nil {
- return err
- }
-
- // setup any configured circuit breakers
- var cbModule = "http.handlers.reverse_proxy.circuit_breaker"
- var cb CircuitBreaker
-
- if uc.CircuitBreaker != nil {
- if _, err := caddy.GetModule(cbModule); err == nil {
- val, err := ctx.LoadModule(cbModule, uc.CircuitBreaker)
- if err == nil {
- cbv, ok := val.(CircuitBreaker)
- if ok {
- cb = cbv
- } else {
- fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error())
- cb = noopCircuitBreaker{}
- }
- } else {
- fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error())
- cb = noopCircuitBreaker{}
- }
- } else {
- fmt.Println("circuit_breaker module not loaded, using noop")
- cb = noopCircuitBreaker{}
- }
- } else {
- cb = noopCircuitBreaker{}
- }
- nu.CB = cb
-
- // start a healthcheck worker which will periodically check to see if an upstream is healthy
- // to proxy requests to.
- nu.healthChecker = NewHealthCheckWorker(nu, defaultHealthCheckDur, defaultHTTPClient)
-
- // TODO :- if path is empty why does this empty the entire Target?
- // nu.Target.Path = uc.HealthCheckPath
-
- nu.healthChecker.ScheduleChecks(nu.Target.String())
- lb.HealthCheckers = append(lb.HealthCheckers, nu.healthChecker)
-
- us = append(us, nu)
- }
-
- lb.upstreams = us
-
- return nil
-}
-
-// LoadBalanced represents a collection of upstream hosts that are loadbalanced. It
-// contains multiple features like health checking and circuit breaking functionality
-// for upstreams.
-type LoadBalanced struct {
- mu sync.Mutex
- numUnhealthy int32
- selectedServer int // used during round robin load balancing
- loadBalanceType int
- tryInterval time.Duration
- upstreams []*upstream
-
- // The following struct fields are set by caddy configuration.
- // TryInterval is the max duration for which request retrys will be performed for a request.
- TryInterval string `json:"try_interval,omitempty"`
-
- // Upstreams are the configs for upstream hosts
- Upstreams []*UpstreamConfig `json:"upstreams,omitempty"`
-
- // LoadBalanceType is the string representation of what loadbalancing algorithm to use. i.e. "random" or "round_robin".
- LoadBalanceType string `json:"load_balance_type,omitempty"`
-
- // NoHealthyUpstreamsMessage is returned as a response when there are no healthy upstreams to loadbalance to.
- NoHealthyUpstreamsMessage string `json:"no_healthy_upstreams_message,omitempty"`
-
- // TODO :- store healthcheckers as package level state where each upstream gets a single healthchecker
- // currently a healthchecker is created for each upstream defined, even if a healthchecker was previously created
- // for that upstream
- HealthCheckers []*HealthChecker `json:"health_checkers,omitempty"`
-}
-
-// Cleanup stops all health checkers on a loadbalanced reverse proxy.
-func (lb *LoadBalanced) Cleanup() error {
- for _, hc := range lb.HealthCheckers {
- hc.Stop()
- }
-
- return nil
-}
-
-// Provision sets up a new loadbalanced reverse proxy.
-func (lb *LoadBalanced) Provision(ctx caddy.Context) error {
- return NewLoadBalancedReverseProxy(lb, ctx)
-}
-
-// ServeHTTP implements the caddyhttp.MiddlewareHandler interface to
-// dispatch an HTTP request to the proper server.
-func (lb *LoadBalanced) ServeHTTP(w http.ResponseWriter, r *http.Request, _ caddyhttp.Handler) error {
- // ensure requests don't hang if an upstream does not respond or is not eventually healthy
- var u *upstream
- var done bool
-
- retryTimer := time.NewTicker(lb.tryInterval)
- defer retryTimer.Stop()
-
- go func() {
- select {
- case <-retryTimer.C:
- done = true
- }
- }()
-
- // keep trying to get an available upstream to process the request
- for {
- switch lb.loadBalanceType {
- case TypeBalanceRandom:
- u = lb.random()
- case TypeBalanceRoundRobin:
- u = lb.roundRobin()
- }
-
- // if we can't get an upstream and our retry interval has ended return an error response
- if u == nil && done {
- w.WriteHeader(http.StatusBadGateway)
- fmt.Fprint(w, lb.NoHealthyUpstreamsMessage)
-
- return fmt.Errorf(msgNoHealthyUpstreams)
- }
-
- // attempt to get an available upstream
- if u == nil {
- continue
- }
-
- start := time.Now()
-
- // if we get an error retry until we get a healthy upstream
- res, err := u.ReverseProxy.ServeHTTP(w, r)
- if err != nil {
- if err == context.Canceled {
- return nil
- }
-
- continue
- }
-
- // record circuit breaker metrics
- go u.CB.RecordMetric(res.StatusCode, time.Now().Sub(start))
-
- return nil
- }
-}
-
-// incrUnhealthy increments the amount of unhealthy nodes in a loadbalancer.
-func (lb *LoadBalanced) incrUnhealthy() {
- atomic.AddInt32(&lb.numUnhealthy, 1)
-}
-
-// decrUnhealthy decrements the amount of unhealthy nodes in a loadbalancer.
-func (lb *LoadBalanced) decrUnhealthy() {
- atomic.AddInt32(&lb.numUnhealthy, -1)
-}
-
-// roundRobin implements a round robin load balancing algorithm to select
-// which server to forward requests to.
-func (lb *LoadBalanced) roundRobin() *upstream {
- if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) {
- return nil
- }
-
- selected := lb.upstreams[lb.selectedServer]
-
- lb.mu.Lock()
- lb.selectedServer++
- if lb.selectedServer >= len(lb.upstreams) {
- lb.selectedServer = 0
- }
- lb.mu.Unlock()
-
- if selected.IsHealthy() && selected.CB.Ok() {
- return selected
- }
-
- return nil
-}
-
-// random implements a random server selector for load balancing.
-func (lb *LoadBalanced) random() *upstream {
- if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) {
- return nil
- }
-
- n := rand.Int() % len(lb.upstreams)
- selected := lb.upstreams[n]
-
- if selected.IsHealthy() && selected.CB.Ok() {
- return selected
- }
-
- return nil
-}
-
-// UpstreamConfig represents the config of an upstream.
-type UpstreamConfig struct {
- // Host is the host name of the upstream server.
- Host string `json:"host,omitempty"`
-
- // FastHealthCheckDuration is the duration for which a health check is performed when a node is considered unhealthy.
- FastHealthCheckDuration string `json:"fast_health_check_duration,omitempty"`
-
- CircuitBreaker json.RawMessage `json:"circuit_breaker,omitempty"`
-
- // // CircuitBreakerConfig is the config passed to setup a circuit breaker.
- // CircuitBreakerConfig *circuitbreaker.Config `json:"circuit_breaker,omitempty"`
- circuitbreaker CircuitBreaker
-
- // HealthCheckDuration is the default duration for which a health check is performed.
- HealthCheckDuration string `json:"health_check_duration,omitempty"`
-
- // HealthCheckPath is the path at the upstream host to use for healthchecks.
- HealthCheckPath string `json:"health_check_path,omitempty"`
-}
-
-// upstream represents an upstream host.
-type upstream struct {
- Healthy int32 // 0 = false, 1 = true
- Target *url.URL
- ReverseProxy *ReverseProxy
- Incr func()
- Decr func()
- CB CircuitBreaker
- healthChecker *HealthChecker
- healthCheckDur time.Duration
- fastHealthCheckDur time.Duration
-}
-
-// newUpstream returns a new upstream.
-func newUpstream(uc *UpstreamConfig, d func(), i func()) (*upstream, error) {
- host := strings.TrimSpace(uc.Host)
- protoIdx := strings.Index(host, "://")
- if protoIdx == -1 || len(host[:protoIdx]) == 0 {
- return nil, fmt.Errorf("protocol is required for host")
- }
-
- hostURL, err := url.Parse(host)
- if err != nil {
- return nil, err
- }
-
- // parse healthcheck durations
- hcd, err := time.ParseDuration(uc.HealthCheckDuration)
- if err != nil {
- hcd = defaultHealthCheckDur
- }
-
- fhcd, err := time.ParseDuration(uc.FastHealthCheckDuration)
- if err != nil {
- fhcd = defaultFastHealthCheckDur
- }
-
- u := upstream{
- healthCheckDur: hcd,
- fastHealthCheckDur: fhcd,
- Target: hostURL,
- Decr: d,
- Incr: i,
- Healthy: int32(0), // assume is unhealthy on start
- }
-
- u.ReverseProxy = newReverseProxy(hostURL, u.SetHealthiness)
- return &u, nil
-}
-
-// SetHealthiness sets whether an upstream is healthy or not. The health check worker is updated to
-// perform checks faster if a node is unhealthy.
-func (u *upstream) SetHealthiness(ok bool) {
- h := atomic.LoadInt32(&u.Healthy)
- var wasHealthy bool
- if h == 1 {
- wasHealthy = true
- } else {
- wasHealthy = false
- }
-
- if ok {
- u.healthChecker.Ticker = time.NewTicker(u.healthCheckDur)
-
- if !wasHealthy {
- atomic.AddInt32(&u.Healthy, 1)
- u.Decr()
- }
- } else {
- u.healthChecker.Ticker = time.NewTicker(u.fastHealthCheckDur)
-
- if wasHealthy {
- atomic.AddInt32(&u.Healthy, -1)
- u.Incr()
- }
- }
-}
-
-// IsHealthy returns whether an Upstream is healthy or not.
-func (u *upstream) IsHealthy() bool {
- i := atomic.LoadInt32(&u.Healthy)
- if i == 1 {
- return true
- }
-
- return false
-}
-
-// newReverseProxy returns a new reverse proxy handler.
-func newReverseProxy(target *url.URL, setHealthiness func(bool)) *ReverseProxy {
- errorHandler := func(w http.ResponseWriter, r *http.Request, err error) {
- // we don't need to worry about cancelled contexts since this doesn't necessarilly mean that
- // the upstream is unhealthy.
- if err != context.Canceled {
- setHealthiness(false)
- }
- }
-
- rp := NewSingleHostReverseProxy(target)
- rp.ErrorHandler = errorHandler
- rp.Transport = defaultTransport // use default transport that times out in 5 seconds
- return rp
-}
-
-// Interface guards
-var (
- _ caddyhttp.MiddlewareHandler = (*LoadBalanced)(nil)
- _ caddy.Provisioner = (*LoadBalanced)(nil)
- _ caddy.CleanerUpper = (*LoadBalanced)(nil)
-)