From 026df7c5cb33331d223afc6a9599274e8c89dfd9 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Mon, 2 Sep 2019 22:01:02 -0600 Subject: reverse_proxy: WIP refactor and support for FastCGI --- modules/caddyhttp/caddyhttp.go | 14 + modules/caddyhttp/matchers.go | 5 +- modules/caddyhttp/reverseproxy/fastcgi/client.go | 578 ++++++++++++++ .../caddyhttp/reverseproxy/fastcgi/client_test.go | 301 ++++++++ modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go | 342 +++++++++ modules/caddyhttp/reverseproxy/healthchecker.go | 86 --- modules/caddyhttp/reverseproxy/httptransport.go | 133 ++++ modules/caddyhttp/reverseproxy/module.go | 53 -- modules/caddyhttp/reverseproxy/reverseproxy.go | 846 ++++++++++++++------- .../caddyhttp/reverseproxy/selectionpolicies.go | 351 +++++++++ .../reverseproxy/selectionpolicies_test.go | 363 +++++++++ modules/caddyhttp/reverseproxy/upstream.go | 450 ----------- 12 files changed, 2649 insertions(+), 873 deletions(-) create mode 100644 modules/caddyhttp/reverseproxy/fastcgi/client.go create mode 100644 modules/caddyhttp/reverseproxy/fastcgi/client_test.go create mode 100644 modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go delete mode 100755 modules/caddyhttp/reverseproxy/healthchecker.go create mode 100644 modules/caddyhttp/reverseproxy/httptransport.go delete mode 100755 modules/caddyhttp/reverseproxy/module.go mode change 100755 => 100644 modules/caddyhttp/reverseproxy/reverseproxy.go create mode 100644 modules/caddyhttp/reverseproxy/selectionpolicies.go create mode 100644 modules/caddyhttp/reverseproxy/selectionpolicies_test.go delete mode 100755 modules/caddyhttp/reverseproxy/upstream.go (limited to 'modules') diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go index b4b1ec6..300b5fd 100644 --- a/modules/caddyhttp/caddyhttp.go +++ b/modules/caddyhttp/caddyhttp.go @@ -518,6 +518,20 @@ func (ws WeakString) String() string { return string(ws) } +// StatusCodeMatches returns true if a real HTTP status code matches +// the configured status code, which may be either a real HTTP status +// code or an integer representing a class of codes (e.g. 4 for all +// 4xx statuses). +func StatusCodeMatches(actual, configured int) bool { + if actual == configured { + return true + } + if configured < 100 && actual >= configured*100 && actual < (configured+1)*100 { + return true + } + return false +} + const ( // DefaultHTTPPort is the default port for HTTP. DefaultHTTPPort = 80 diff --git a/modules/caddyhttp/matchers.go b/modules/caddyhttp/matchers.go index 0dac151..2ddefd0 100644 --- a/modules/caddyhttp/matchers.go +++ b/modules/caddyhttp/matchers.go @@ -616,10 +616,7 @@ func (rm ResponseMatcher) matchStatusCode(statusCode int) bool { return true } for _, code := range rm.StatusCode { - if statusCode == code { - return true - } - if code < 100 && statusCode >= code*100 && statusCode < (code+1)*100 { + if StatusCodeMatches(statusCode, code) { return true } } diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client.go b/modules/caddyhttp/reverseproxy/fastcgi/client.go new file mode 100644 index 0000000..ae0de00 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/fastcgi/client.go @@ -0,0 +1,578 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Forked Jan. 2015 from http://bitbucket.org/PinIdea/fcgi_client +// (which is forked from https://code.google.com/p/go-fastcgi-client/). +// This fork contains several fixes and improvements by Matt Holt and +// other contributors to the Caddy project. + +// Copyright 2012 Junqing Tan and The Go Authors +// Use of this source code is governed by a BSD-style +// Part of source code is from Go fcgi package + +package fastcgi + +import ( + "bufio" + "bytes" + "context" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "mime/multipart" + "net" + "net/http" + "net/http/httputil" + "net/textproto" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" +) + +// FCGIListenSockFileno describes listen socket file number. +const FCGIListenSockFileno uint8 = 0 + +// FCGIHeaderLen describes header length. +const FCGIHeaderLen uint8 = 8 + +// Version1 describes the version. +const Version1 uint8 = 1 + +// FCGINullRequestID describes the null request ID. +const FCGINullRequestID uint8 = 0 + +// FCGIKeepConn describes keep connection mode. +const FCGIKeepConn uint8 = 1 + +const ( + // BeginRequest is the begin request flag. + BeginRequest uint8 = iota + 1 + // AbortRequest is the abort request flag. + AbortRequest + // EndRequest is the end request flag. + EndRequest + // Params is the parameters flag. + Params + // Stdin is the standard input flag. + Stdin + // Stdout is the standard output flag. + Stdout + // Stderr is the standard error flag. + Stderr + // Data is the data flag. + Data + // GetValues is the get values flag. + GetValues + // GetValuesResult is the get values result flag. + GetValuesResult + // UnknownType is the unknown type flag. + UnknownType + // MaxType is the maximum type flag. + MaxType = UnknownType +) + +const ( + // Responder is the responder flag. + Responder uint8 = iota + 1 + // Authorizer is the authorizer flag. + Authorizer + // Filter is the filter flag. + Filter +) + +const ( + // RequestComplete is the completed request flag. + RequestComplete uint8 = iota + // CantMultiplexConns is the multiplexed connections flag. + CantMultiplexConns + // Overloaded is the overloaded flag. + Overloaded + // UnknownRole is the unknown role flag. + UnknownRole +) + +const ( + // MaxConns is the maximum connections flag. + MaxConns string = "MAX_CONNS" + // MaxRequests is the maximum requests flag. + MaxRequests string = "MAX_REQS" + // MultiplexConns is the multiplex connections flag. + MultiplexConns string = "MPXS_CONNS" +) + +const ( + maxWrite = 65500 // 65530 may work, but for compatibility + maxPad = 255 +) + +type header struct { + Version uint8 + Type uint8 + ID uint16 + ContentLength uint16 + PaddingLength uint8 + Reserved uint8 +} + +// for padding so we don't have to allocate all the time +// not synchronized because we don't care what the contents are +var pad [maxPad]byte + +func (h *header) init(recType uint8, reqID uint16, contentLength int) { + h.Version = 1 + h.Type = recType + h.ID = reqID + h.ContentLength = uint16(contentLength) + h.PaddingLength = uint8(-contentLength & 7) +} + +type record struct { + h header + rbuf []byte +} + +func (rec *record) read(r io.Reader) (buf []byte, err error) { + if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil { + return + } + if rec.h.Version != 1 { + err = errors.New("fcgi: invalid header version") + return + } + if rec.h.Type == EndRequest { + err = io.EOF + return + } + n := int(rec.h.ContentLength) + int(rec.h.PaddingLength) + if len(rec.rbuf) < n { + rec.rbuf = make([]byte, n) + } + if _, err = io.ReadFull(r, rec.rbuf[:n]); err != nil { + return + } + buf = rec.rbuf[:int(rec.h.ContentLength)] + + return +} + +// FCGIClient implements a FastCGI client, which is a standard for +// interfacing external applications with Web servers. +type FCGIClient struct { + mutex sync.Mutex + rwc io.ReadWriteCloser + h header + buf bytes.Buffer + stderr bytes.Buffer + keepAlive bool + reqID uint16 +} + +// DialWithDialerContext connects to the fcgi responder at the specified network address, using custom net.Dialer +// and a context. +// See func net.Dial for a description of the network and address parameters. +func DialWithDialerContext(ctx context.Context, network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) { + var conn net.Conn + conn, err = dialer.DialContext(ctx, network, address) + if err != nil { + return + } + + fcgi = &FCGIClient{ + rwc: conn, + keepAlive: false, + reqID: 1, + } + + return +} + +// DialContext is like Dial but passes ctx to dialer.Dial. +func DialContext(ctx context.Context, network, address string) (fcgi *FCGIClient, err error) { + // TODO: why not set timeout here? + return DialWithDialerContext(ctx, network, address, net.Dialer{}) +} + +// Dial connects to the fcgi responder at the specified network address, using default net.Dialer. +// See func net.Dial for a description of the network and address parameters. +func Dial(network, address string) (fcgi *FCGIClient, err error) { + return DialContext(context.Background(), network, address) +} + +// Close closes fcgi connection +func (c *FCGIClient) Close() { + c.rwc.Close() +} + +func (c *FCGIClient) writeRecord(recType uint8, content []byte) (err error) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.buf.Reset() + c.h.init(recType, c.reqID, len(content)) + if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { + return err + } + if _, err := c.buf.Write(content); err != nil { + return err + } + if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { + return err + } + _, err = c.rwc.Write(c.buf.Bytes()) + return err +} + +func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error { + b := [8]byte{byte(role >> 8), byte(role), flags} + return c.writeRecord(BeginRequest, b[:]) +} + +func (c *FCGIClient) writeEndRequest(appStatus int, protocolStatus uint8) error { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b, uint32(appStatus)) + b[4] = protocolStatus + return c.writeRecord(EndRequest, b) +} + +func (c *FCGIClient) writePairs(recType uint8, pairs map[string]string) error { + w := newWriter(c, recType) + b := make([]byte, 8) + nn := 0 + for k, v := range pairs { + m := 8 + len(k) + len(v) + if m > maxWrite { + // param data size exceed 65535 bytes" + vl := maxWrite - 8 - len(k) + v = v[:vl] + } + n := encodeSize(b, uint32(len(k))) + n += encodeSize(b[n:], uint32(len(v))) + m = n + len(k) + len(v) + if (nn + m) > maxWrite { + w.Flush() + nn = 0 + } + nn += m + if _, err := w.Write(b[:n]); err != nil { + return err + } + if _, err := w.WriteString(k); err != nil { + return err + } + if _, err := w.WriteString(v); err != nil { + return err + } + } + w.Close() + return nil +} + +func encodeSize(b []byte, size uint32) int { + if size > 127 { + size |= 1 << 31 + binary.BigEndian.PutUint32(b, size) + return 4 + } + b[0] = byte(size) + return 1 +} + +// bufWriter encapsulates bufio.Writer but also closes the underlying stream when +// Closed. +type bufWriter struct { + closer io.Closer + *bufio.Writer +} + +func (w *bufWriter) Close() error { + if err := w.Writer.Flush(); err != nil { + w.closer.Close() + return err + } + return w.closer.Close() +} + +func newWriter(c *FCGIClient, recType uint8) *bufWriter { + s := &streamWriter{c: c, recType: recType} + w := bufio.NewWriterSize(s, maxWrite) + return &bufWriter{s, w} +} + +// streamWriter abstracts out the separation of a stream into discrete records. +// It only writes maxWrite bytes at a time. +type streamWriter struct { + c *FCGIClient + recType uint8 +} + +func (w *streamWriter) Write(p []byte) (int, error) { + nn := 0 + for len(p) > 0 { + n := len(p) + if n > maxWrite { + n = maxWrite + } + if err := w.c.writeRecord(w.recType, p[:n]); err != nil { + return nn, err + } + nn += n + p = p[n:] + } + return nn, nil +} + +func (w *streamWriter) Close() error { + // send empty record to close the stream + return w.c.writeRecord(w.recType, nil) +} + +type streamReader struct { + c *FCGIClient + buf []byte +} + +func (w *streamReader) Read(p []byte) (n int, err error) { + + if len(p) > 0 { + if len(w.buf) == 0 { + + // filter outputs for error log + for { + rec := &record{} + var buf []byte + buf, err = rec.read(w.c.rwc) + if err != nil { + return + } + // standard error output + if rec.h.Type == Stderr { + w.c.stderr.Write(buf) + continue + } + w.buf = buf + break + } + } + + n = len(p) + if n > len(w.buf) { + n = len(w.buf) + } + copy(p, w.buf[:n]) + w.buf = w.buf[n:] + } + + return +} + +// Do made the request and returns a io.Reader that translates the data read +// from fcgi responder out of fcgi packet before returning it. +func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) { + err = c.writeBeginRequest(uint16(Responder), 0) + if err != nil { + return + } + + err = c.writePairs(Params, p) + if err != nil { + return + } + + body := newWriter(c, Stdin) + if req != nil { + _, _ = io.Copy(body, req) + } + body.Close() + + r = &streamReader{c: c} + return +} + +// clientCloser is a io.ReadCloser. It wraps a io.Reader with a Closer +// that closes FCGIClient connection. +type clientCloser struct { + *FCGIClient + io.Reader +} + +func (f clientCloser) Close() error { return f.rwc.Close() } + +// Request returns a HTTP Response with Header and Body +// from fcgi responder +func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Response, err error) { + r, err := c.Do(p, req) + if err != nil { + return + } + + rb := bufio.NewReader(r) + tp := textproto.NewReader(rb) + resp = new(http.Response) + + // Parse the response headers. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil && err != io.EOF { + return + } + resp.Header = http.Header(mimeHeader) + + if resp.Header.Get("Status") != "" { + statusParts := strings.SplitN(resp.Header.Get("Status"), " ", 2) + resp.StatusCode, err = strconv.Atoi(statusParts[0]) + if err != nil { + return + } + if len(statusParts) > 1 { + resp.Status = statusParts[1] + } + + } else { + resp.StatusCode = http.StatusOK + } + + // TODO: fixTransferEncoding ? + resp.TransferEncoding = resp.Header["Transfer-Encoding"] + resp.ContentLength, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + + if chunked(resp.TransferEncoding) { + resp.Body = clientCloser{c, httputil.NewChunkedReader(rb)} + } else { + resp.Body = clientCloser{c, ioutil.NopCloser(rb)} + } + return +} + +// Get issues a GET request to the fcgi responder. +func (c *FCGIClient) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) { + + p["REQUEST_METHOD"] = "GET" + p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10) + + return c.Request(p, body) +} + +// Head issues a HEAD request to the fcgi responder. +func (c *FCGIClient) Head(p map[string]string) (resp *http.Response, err error) { + + p["REQUEST_METHOD"] = "HEAD" + p["CONTENT_LENGTH"] = "0" + + return c.Request(p, nil) +} + +// Options issues an OPTIONS request to the fcgi responder. +func (c *FCGIClient) Options(p map[string]string) (resp *http.Response, err error) { + + p["REQUEST_METHOD"] = "OPTIONS" + p["CONTENT_LENGTH"] = "0" + + return c.Request(p, nil) +} + +// Post issues a POST request to the fcgi responder. with request body +// in the format that bodyType specified +func (c *FCGIClient) Post(p map[string]string, method string, bodyType string, body io.Reader, l int64) (resp *http.Response, err error) { + if p == nil { + p = make(map[string]string) + } + + p["REQUEST_METHOD"] = strings.ToUpper(method) + + if len(p["REQUEST_METHOD"]) == 0 || p["REQUEST_METHOD"] == "GET" { + p["REQUEST_METHOD"] = "POST" + } + + p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10) + if len(bodyType) > 0 { + p["CONTENT_TYPE"] = bodyType + } else { + p["CONTENT_TYPE"] = "application/x-www-form-urlencoded" + } + + return c.Request(p, body) +} + +// PostForm issues a POST to the fcgi responder, with form +// as a string key to a list values (url.Values) +func (c *FCGIClient) PostForm(p map[string]string, data url.Values) (resp *http.Response, err error) { + body := bytes.NewReader([]byte(data.Encode())) + return c.Post(p, "POST", "application/x-www-form-urlencoded", body, int64(body.Len())) +} + +// PostFile issues a POST to the fcgi responder in multipart(RFC 2046) standard, +// with form as a string key to a list values (url.Values), +// and/or with file as a string key to a list file path. +func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[string]string) (resp *http.Response, err error) { + buf := &bytes.Buffer{} + writer := multipart.NewWriter(buf) + bodyType := writer.FormDataContentType() + + for key, val := range data { + for _, v0 := range val { + err = writer.WriteField(key, v0) + if err != nil { + return + } + } + } + + for key, val := range file { + fd, e := os.Open(val) + if e != nil { + return nil, e + } + defer fd.Close() + + part, e := writer.CreateFormFile(key, filepath.Base(val)) + if e != nil { + return nil, e + } + _, err = io.Copy(part, fd) + if err != nil { + return + } + } + + err = writer.Close() + if err != nil { + return + } + + return c.Post(p, "POST", bodyType, buf, int64(buf.Len())) +} + +// SetReadTimeout sets the read timeout for future calls that read from the +// fcgi responder. A zero value for t means no timeout will be set. +func (c *FCGIClient) SetReadTimeout(t time.Duration) error { + if conn, ok := c.rwc.(net.Conn); ok && t != 0 { + return conn.SetReadDeadline(time.Now().Add(t)) + } + return nil +} + +// SetWriteTimeout sets the write timeout for future calls that send data to +// the fcgi responder. A zero value for t means no timeout will be set. +func (c *FCGIClient) SetWriteTimeout(t time.Duration) error { + if conn, ok := c.rwc.(net.Conn); ok && t != 0 { + return conn.SetWriteDeadline(time.Now().Add(t)) + } + return nil +} + +// Checks whether chunked is part of the encodings stack +func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } diff --git a/modules/caddyhttp/reverseproxy/fastcgi/client_test.go b/modules/caddyhttp/reverseproxy/fastcgi/client_test.go new file mode 100644 index 0000000..c090f3c --- /dev/null +++ b/modules/caddyhttp/reverseproxy/fastcgi/client_test.go @@ -0,0 +1,301 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// NOTE: These tests were adapted from the original +// repository from which this package was forked. +// The tests are slow (~10s) and in dire need of rewriting. +// As such, the tests have been disabled to speed up +// automated builds until they can be properly written. + +package fastcgi + +import ( + "bytes" + "crypto/md5" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "log" + "math/rand" + "net" + "net/http" + "net/http/fcgi" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" +) + +// test fcgi protocol includes: +// Get, Post, Post in multipart/form-data, and Post with files +// each key should be the md5 of the value or the file uploaded +// specify remote fcgi responder ip:port to test with php +// test failed if the remote fcgi(script) failed md5 verification +// and output "FAILED" in response +const ( + scriptFile = "/tank/www/fcgic_test.php" + //ipPort = "remote-php-serv:59000" + ipPort = "127.0.0.1:59000" +) + +var globalt *testing.T + +type FastCGIServer struct{} + +func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + + if err := req.ParseMultipartForm(100000000); err != nil { + log.Printf("[ERROR] failed to parse: %v", err) + } + + stat := "PASSED" + fmt.Fprintln(resp, "-") + fileNum := 0 + { + length := 0 + for k0, v0 := range req.Form { + h := md5.New() + _, _ = io.WriteString(h, v0[0]) + _md5 := fmt.Sprintf("%x", h.Sum(nil)) + + length += len(k0) + length += len(v0[0]) + + // echo error when key != _md5(val) + if _md5 != k0 { + fmt.Fprintln(resp, "server:err ", _md5, k0) + stat = "FAILED" + } + } + if req.MultipartForm != nil { + fileNum = len(req.MultipartForm.File) + for kn, fns := range req.MultipartForm.File { + //fmt.Fprintln(resp, "server:filekey ", kn ) + length += len(kn) + for _, f := range fns { + fd, err := f.Open() + if err != nil { + log.Println("server:", err) + return + } + h := md5.New() + l0, err := io.Copy(h, fd) + if err != nil { + log.Println(err) + return + } + length += int(l0) + defer fd.Close() + md5 := fmt.Sprintf("%x", h.Sum(nil)) + //fmt.Fprintln(resp, "server:filemd5 ", md5 ) + + if kn != md5 { + fmt.Fprintln(resp, "server:err ", md5, kn) + stat = "FAILED" + } + //fmt.Fprintln(resp, "server:filename ", f.Filename ) + } + } + } + + fmt.Fprintln(resp, "server:got data length", length) + } + fmt.Fprintln(resp, "-"+stat+"-POST(", len(req.Form), ")-FILE(", fileNum, ")--") +} + +func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) { + fcgi, err := Dial("tcp", ipPort) + if err != nil { + log.Println("err:", err) + return + } + + length := 0 + + var resp *http.Response + switch reqType { + case 0: + if len(data) > 0 { + length = len(data) + rd := bytes.NewReader(data) + resp, err = fcgi.Post(fcgiParams, "", "", rd, int64(rd.Len())) + } else if len(posts) > 0 { + values := url.Values{} + for k, v := range posts { + values.Set(k, v) + length += len(k) + 2 + len(v) + } + resp, err = fcgi.PostForm(fcgiParams, values) + } else { + rd := bytes.NewReader(data) + resp, err = fcgi.Get(fcgiParams, rd, int64(rd.Len())) + } + + default: + values := url.Values{} + for k, v := range posts { + values.Set(k, v) + length += len(k) + 2 + len(v) + } + + for k, v := range files { + fi, _ := os.Lstat(v) + length += len(k) + int(fi.Size()) + } + resp, err = fcgi.PostFile(fcgiParams, values, files) + } + + if err != nil { + log.Println("err:", err) + return + } + + defer resp.Body.Close() + content, _ = ioutil.ReadAll(resp.Body) + + log.Println("c: send data length ≈", length, string(content)) + fcgi.Close() + time.Sleep(1 * time.Second) + + if bytes.Contains(content, []byte("FAILED")) { + globalt.Error("Server return failed message") + } + + return +} + +func generateRandFile(size int) (p string, m string) { + + p = filepath.Join(os.TempDir(), "fcgict"+strconv.Itoa(rand.Int())) + + // open output file + fo, err := os.Create(p) + if err != nil { + panic(err) + } + // close fo on exit and check for its returned error + defer func() { + if err := fo.Close(); err != nil { + panic(err) + } + }() + + h := md5.New() + for i := 0; i < size/16; i++ { + buf := make([]byte, 16) + binary.PutVarint(buf, rand.Int63()) + if _, err := fo.Write(buf); err != nil { + log.Printf("[ERROR] failed to write buffer: %v\n", err) + } + if _, err := h.Write(buf); err != nil { + log.Printf("[ERROR] failed to write buffer: %v\n", err) + } + } + m = fmt.Sprintf("%x", h.Sum(nil)) + return +} + +func DisabledTest(t *testing.T) { + // TODO: test chunked reader + globalt = t + + rand.Seed(time.Now().UTC().UnixNano()) + + // server + go func() { + listener, err := net.Listen("tcp", ipPort) + if err != nil { + log.Println("listener creation failed: ", err) + } + + srv := new(FastCGIServer) + if err := fcgi.Serve(listener, srv); err != nil { + log.Print("[ERROR] failed to start server: ", err) + } + }() + + time.Sleep(1 * time.Second) + + // init + fcgiParams := make(map[string]string) + fcgiParams["REQUEST_METHOD"] = "GET" + fcgiParams["SERVER_PROTOCOL"] = "HTTP/1.1" + //fcgi_params["GATEWAY_INTERFACE"] = "CGI/1.1" + fcgiParams["SCRIPT_FILENAME"] = scriptFile + + // simple GET + log.Println("test:", "get") + sendFcgi(0, fcgiParams, nil, nil, nil) + + // simple post data + log.Println("test:", "post") + sendFcgi(0, fcgiParams, []byte("c4ca4238a0b923820dcc509a6f75849b=1&7b8b965ad4bca0e41ab51de7b31363a1=n"), nil, nil) + + log.Println("test:", "post data (more than 60KB)") + data := "" + for i := 0x00; i < 0xff; i++ { + v0 := strings.Repeat(string(i), 256) + h := md5.New() + _, _ = io.WriteString(h, v0) + k0 := fmt.Sprintf("%x", h.Sum(nil)) + data += k0 + "=" + url.QueryEscape(v0) + "&" + } + sendFcgi(0, fcgiParams, []byte(data), nil, nil) + + log.Println("test:", "post form (use url.Values)") + p0 := make(map[string]string, 1) + p0["c4ca4238a0b923820dcc509a6f75849b"] = "1" + p0["7b8b965ad4bca0e41ab51de7b31363a1"] = "n" + sendFcgi(1, fcgiParams, nil, p0, nil) + + log.Println("test:", "post forms (256 keys, more than 1MB)") + p1 := make(map[string]string, 1) + for i := 0x00; i < 0xff; i++ { + v0 := strings.Repeat(string(i), 4096) + h := md5.New() + _, _ = io.WriteString(h, v0) + k0 := fmt.Sprintf("%x", h.Sum(nil)) + p1[k0] = v0 + } + sendFcgi(1, fcgiParams, nil, p1, nil) + + log.Println("test:", "post file (1 file, 500KB)) ") + f0 := make(map[string]string, 1) + path0, m0 := generateRandFile(500000) + f0[m0] = path0 + sendFcgi(1, fcgiParams, nil, p1, f0) + + log.Println("test:", "post multiple files (2 files, 5M each) and forms (256 keys, more than 1MB data") + path1, m1 := generateRandFile(5000000) + f0[m1] = path1 + sendFcgi(1, fcgiParams, nil, p1, f0) + + log.Println("test:", "post only files (2 files, 5M each)") + sendFcgi(1, fcgiParams, nil, nil, f0) + + log.Println("test:", "post only 1 file") + delete(f0, "m0") + sendFcgi(1, fcgiParams, nil, nil, f0) + + if err := os.Remove(path0); err != nil { + log.Println("[ERROR] failed to remove path: ", err) + } + if err := os.Remove(path1); err != nil { + log.Println("[ERROR] failed to remove path: ", err) + } +} diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go new file mode 100644 index 0000000..32f094b --- /dev/null +++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go @@ -0,0 +1,342 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fastcgi + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "path" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/caddyserver/caddy/v2/modules/caddytls" + + "github.com/caddyserver/caddy/v2" +) + +func init() { + caddy.RegisterModule(Transport{}) +} + +type Transport struct { + ////////////////////////////// + // TODO: taken from v1 Handler type + + SoftwareName string + SoftwareVersion string + ServerName string + ServerPort string + + ////////////////////////// + // TODO: taken from v1 Rule type + + // The base path to match. Required. + // Path string + + // upstream load balancer + // balancer + + // Always process files with this extension with fastcgi. + // Ext string + + // Use this directory as the fastcgi root directory. Defaults to the root + // directory of the parent virtual host. + Root string + + // The path in the URL will be split into two, with the first piece ending + // with the value of SplitPath. The first piece will be assumed as the + // actual resource (CGI script) name, and the second piece will be set to + // PATH_INFO for the CGI script to use. + SplitPath string + + // If the URL ends with '/' (which indicates a directory), these index + // files will be tried instead. + IndexFiles []string + + // Environment Variables + EnvVars [][2]string + + // Ignored paths + IgnoredSubPaths []string + + // The duration used to set a deadline when connecting to an upstream. + DialTimeout time.Duration + + // The duration used to set a deadline when reading from the FastCGI server. + ReadTimeout time.Duration + + // The duration used to set a deadline when sending to the FastCGI server. + WriteTimeout time.Duration +} + +// CaddyModule returns the Caddy module information. +func (Transport) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.transport.fastcgi", + New: func() caddy.Module { return new(Transport) }, + } +} + +func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) { + // Create environment for CGI script + env, err := t.buildEnv(r) + if err != nil { + return nil, fmt.Errorf("building environment: %v", err) + } + + // TODO: + // Connect to FastCGI gateway + // address, err := f.Address() + // if err != nil { + // return http.StatusBadGateway, err + // } + // network, address := parseAddress(address) + network, address := "tcp", r.URL.Host // TODO: + + ctx := context.Background() + if t.DialTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, t.DialTimeout) + defer cancel() + } + + fcgiBackend, err := DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("dialing backend: %v", err) + } + // fcgiBackend is closed when response body is closed (see clientCloser) + + // read/write timeouts + if err := fcgiBackend.SetReadTimeout(t.ReadTimeout); err != nil { + return nil, fmt.Errorf("setting read timeout: %v", err) + } + if err := fcgiBackend.SetWriteTimeout(t.WriteTimeout); err != nil { + return nil, fmt.Errorf("setting write timeout: %v", err) + } + + var resp *http.Response + + var contentLength int64 + // if ContentLength is already set + if r.ContentLength > 0 { + contentLength = r.ContentLength + } else { + contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64) + } + switch r.Method { + case "HEAD": + resp, err = fcgiBackend.Head(env) + case "GET": + resp, err = fcgiBackend.Get(env, r.Body, contentLength) + case "OPTIONS": + resp, err = fcgiBackend.Options(env) + default: + resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength) + } + + // TODO: + return resp, err + + // Stuff brought over from v1 that might not be necessary here: + + // if resp != nil && resp.Body != nil { + // defer resp.Body.Close() + // } + + // if err != nil { + // if err, ok := err.(net.Error); ok && err.Timeout() { + // return http.StatusGatewayTimeout, err + // } else if err != io.EOF { + // return http.StatusBadGateway, err + // } + // } + + // // Write response header + // writeHeader(w, resp) + + // // Write the response body + // _, err = io.Copy(w, resp.Body) + // if err != nil { + // return http.StatusBadGateway, err + // } + + // // Log any stderr output from upstream + // if fcgiBackend.stderr.Len() != 0 { + // // Remove trailing newline, error logger already does this. + // err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n")) + // } + + // // Normally we would return the status code if it is an error status (>= 400), + // // however, upstream FastCGI apps don't know about our contract and have + // // probably already written an error page. So we just return 0, indicating + // // that the response body is already written. However, we do return any + // // error value so it can be logged. + // // Note that the proxy middleware works the same way, returning status=0. + // return 0, err +} + +// buildEnv returns a set of CGI environment variables for the request. +func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { + var env map[string]string + + // Separate remote IP and port; more lenient than net.SplitHostPort + var ip, port string + if idx := strings.LastIndex(r.RemoteAddr, ":"); idx > -1 { + ip = r.RemoteAddr[:idx] + port = r.RemoteAddr[idx+1:] + } else { + ip = r.RemoteAddr + } + + // Remove [] from IPv6 addresses + ip = strings.Replace(ip, "[", "", 1) + ip = strings.Replace(ip, "]", "", 1) + + // TODO: respect index files? or leave that to matcher/rewrite (I prefer that)? + fpath := r.URL.Path + + // Split path in preparation for env variables. + // Previous canSplit checks ensure this can never be -1. + // TODO: I haven't brought over canSplit; make sure this doesn't break + splitPos := t.splitPos(fpath) + + // Request has the extension; path was split successfully + docURI := fpath[:splitPos+len(t.SplitPath)] + pathInfo := fpath[splitPos+len(t.SplitPath):] + scriptName := fpath + + // Strip PATH_INFO from SCRIPT_NAME + scriptName = strings.TrimSuffix(scriptName, pathInfo) + + // SCRIPT_FILENAME is the absolute path of SCRIPT_NAME + scriptFilename := filepath.Join(t.Root, scriptName) + + // Add vhost path prefix to scriptName. Otherwise, some PHP software will + // have difficulty discovering its URL. + pathPrefix, _ := r.Context().Value(caddy.CtxKey("path_prefix")).(string) + scriptName = path.Join(pathPrefix, scriptName) + + // TODO: Disabled for now + // // Get the request URI from context. The context stores the original URI in case + // // it was changed by a middleware such as rewrite. By default, we pass the + // // original URI in as the value of REQUEST_URI (the user can overwrite this + // // if desired). Most PHP apps seem to want the original URI. Besides, this is + // // how nginx defaults: http://stackoverflow.com/a/12485156/1048862 + // reqURL, _ := r.Context().Value(httpserver.OriginalURLCtxKey).(url.URL) + + // // Retrieve name of remote user that was set by some downstream middleware such as basicauth. + // remoteUser, _ := r.Context().Value(httpserver.RemoteUserCtxKey).(string) + + requestScheme := "http" + if r.TLS != nil { + requestScheme = "https" + } + + // Some variables are unused but cleared explicitly to prevent + // the parent environment from interfering. + env = map[string]string{ + // Variables defined in CGI 1.1 spec + "AUTH_TYPE": "", // Not used + "CONTENT_LENGTH": r.Header.Get("Content-Length"), + "CONTENT_TYPE": r.Header.Get("Content-Type"), + "GATEWAY_INTERFACE": "CGI/1.1", + "PATH_INFO": pathInfo, + "QUERY_STRING": r.URL.RawQuery, + "REMOTE_ADDR": ip, + "REMOTE_HOST": ip, // For speed, remote host lookups disabled + "REMOTE_PORT": port, + "REMOTE_IDENT": "", // Not used + // "REMOTE_USER": remoteUser, // TODO: + "REQUEST_METHOD": r.Method, + "REQUEST_SCHEME": requestScheme, + "SERVER_NAME": t.ServerName, + "SERVER_PORT": t.ServerPort, + "SERVER_PROTOCOL": r.Proto, + "SERVER_SOFTWARE": t.SoftwareName + "/" + t.SoftwareVersion, + + // Other variables + // "DOCUMENT_ROOT": rule.Root, + "DOCUMENT_URI": docURI, + "HTTP_HOST": r.Host, // added here, since not always part of headers + // "REQUEST_URI": reqURL.RequestURI(), // TODO: + "SCRIPT_FILENAME": scriptFilename, + "SCRIPT_NAME": scriptName, + } + + // compliance with the CGI specification requires that + // PATH_TRANSLATED should only exist if PATH_INFO is defined. + // Info: https://www.ietf.org/rfc/rfc3875 Page 14 + if env["PATH_INFO"] != "" { + env["PATH_TRANSLATED"] = filepath.Join(t.Root, pathInfo) // Info: http://www.oreilly.com/openbook/cgi/ch02_04.html + } + + // Some web apps rely on knowing HTTPS or not + if r.TLS != nil { + env["HTTPS"] = "on" + // and pass the protocol details in a manner compatible with apache's mod_ssl + // (which is why these have a SSL_ prefix and not TLS_). + v, ok := tlsProtocolStrings[r.TLS.Version] + if ok { + env["SSL_PROTOCOL"] = v + } + // and pass the cipher suite in a manner compatible with apache's mod_ssl + for k, v := range caddytls.SupportedCipherSuites { + if v == r.TLS.CipherSuite { + env["SSL_CIPHER"] = k + break + } + } + } + + // Add env variables from config (with support for placeholders in values) + repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) + for _, envVar := range t.EnvVars { + env[envVar[0]] = repl.ReplaceAll(envVar[1], "") + } + + // Add all HTTP headers to env variables + for field, val := range r.Header { + header := strings.ToUpper(field) + header = headerNameReplacer.Replace(header) + env["HTTP_"+header] = strings.Join(val, ", ") + } + return env, nil +} + +// splitPos returns the index where path should +// be split based on t.SplitPath. +func (t Transport) splitPos(path string) int { + // TODO: + // if httpserver.CaseSensitivePath { + // return strings.Index(path, r.SplitPath) + // } + return strings.Index(strings.ToLower(path), strings.ToLower(t.SplitPath)) +} + +// TODO: +// Map of supported protocols to Apache ssl_mod format +// Note that these are slightly different from SupportedProtocols in caddytls/config.go +var tlsProtocolStrings = map[uint16]string{ + tls.VersionTLS10: "TLSv1", + tls.VersionTLS11: "TLSv1.1", + tls.VersionTLS12: "TLSv1.2", + tls.VersionTLS13: "TLSv1.3", +} + +var headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_") diff --git a/modules/caddyhttp/reverseproxy/healthchecker.go b/modules/caddyhttp/reverseproxy/healthchecker.go deleted file mode 100755 index c557d3f..0000000 --- a/modules/caddyhttp/reverseproxy/healthchecker.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2015 Matthew Holt and The Caddy Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package reverseproxy - -import ( - "net/http" - "time" -) - -// Upstream represents the interface that must be satisfied to use the healthchecker. -type Upstream interface { - SetHealthiness(bool) -} - -// HealthChecker represents a worker that periodically evaluates if proxy upstream host is healthy. -type HealthChecker struct { - upstream Upstream - Ticker *time.Ticker - HTTPClient *http.Client - StopChan chan bool -} - -// ScheduleChecks periodically runs health checks against an upstream host. -func (h *HealthChecker) ScheduleChecks(url string) { - // check if a host is healthy on start vs waiting for timer - h.upstream.SetHealthiness(h.IsHealthy(url)) - stop := make(chan bool) - h.StopChan = stop - - go func() { - for { - select { - case <-h.Ticker.C: - h.upstream.SetHealthiness(h.IsHealthy(url)) - case <-stop: - return - } - } - }() -} - -// Stop stops the healthchecker from makeing further requests. -func (h *HealthChecker) Stop() { - h.Ticker.Stop() - close(h.StopChan) -} - -// IsHealthy attempts to check if a upstream host is healthy. -func (h *HealthChecker) IsHealthy(url string) bool { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return false - } - - resp, err := h.HTTPClient.Do(req) - if err != nil { - return false - } - - if resp.StatusCode < 200 || resp.StatusCode >= 400 { - return false - } - - return true -} - -// NewHealthCheckWorker returns a new instance of a HealthChecker. -func NewHealthCheckWorker(u Upstream, interval time.Duration, client *http.Client) *HealthChecker { - return &HealthChecker{ - upstream: u, - Ticker: time.NewTicker(interval), - HTTPClient: client, - } -} diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go new file mode 100644 index 0000000..36dd776 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -0,0 +1,133 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "net" + "net/http" + "time" + + "github.com/caddyserver/caddy/v2" +) + +func init() { + caddy.RegisterModule(HTTPTransport{}) +} + +// TODO: This is the default transport, basically just http.Transport, but we define JSON struct tags... +type HTTPTransport struct { + // TODO: Actually this is where the TLS config should go, technically... + // as well as keepalives and dial timeouts... + // TODO: It's possible that other transports (like fastcgi) might be + // able to borrow/use at least some of these config fields; if so, + // move them into a type called CommonTransport and embed it + + TLS *TLSConfig `json:"tls,omitempty"` + KeepAlive *KeepAlive `json:"keep_alive,omitempty"` + Compression *bool `json:"compression,omitempty"` + MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // TODO: NOTE: we use our health check stuff to enforce max REQUESTS per host, but this is connections + DialTimeout caddy.Duration `json:"dial_timeout,omitempty"` + FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"` + ResponseHeaderTimeout caddy.Duration `json:"response_header_timeout,omitempty"` + ExpectContinueTimeout caddy.Duration `json:"expect_continue_timeout,omitempty"` + MaxResponseHeaderSize int64 `json:"max_response_header_size,omitempty"` + WriteBufferSize int `json:"write_buffer_size,omitempty"` + ReadBufferSize int `json:"read_buffer_size,omitempty"` + // TODO: ProxyConnectHeader? + + RoundTripper http.RoundTripper `json:"-"` +} + +// CaddyModule returns the Caddy module information. +func (HTTPTransport) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.transport.http", + New: func() caddy.Module { return new(HTTPTransport) }, + } +} + +func (h *HTTPTransport) Provision(ctx caddy.Context) error { + dialer := &net.Dialer{ + Timeout: time.Duration(h.DialTimeout), + FallbackDelay: time.Duration(h.FallbackDelay), + // TODO: Resolver + } + rt := &http.Transport{ + DialContext: dialer.DialContext, + MaxConnsPerHost: h.MaxConnsPerHost, + ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout), + ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout), + MaxResponseHeaderBytes: h.MaxResponseHeaderSize, + WriteBufferSize: h.WriteBufferSize, + ReadBufferSize: h.ReadBufferSize, + } + + if h.TLS != nil { + rt.TLSHandshakeTimeout = time.Duration(h.TLS.HandshakeTimeout) + // TODO: rest of TLS config + } + + if h.KeepAlive != nil { + dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval) + + if enabled := h.KeepAlive.Enabled; enabled != nil { + rt.DisableKeepAlives = !*enabled + } + rt.MaxIdleConns = h.KeepAlive.MaxIdleConns + rt.MaxIdleConnsPerHost = h.KeepAlive.MaxIdleConnsPerHost + rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout) + } + + if h.Compression != nil { + rt.DisableCompression = !*h.Compression + } + + h.RoundTripper = rt + + return nil +} + +func (h HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return h.RoundTripper.RoundTrip(req) +} + +type TLSConfig struct { + CAPool []string `json:"ca_pool,omitempty"` + ClientCertificate string `json:"client_certificate,omitempty"` + InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` + HandshakeTimeout caddy.Duration `json:"handshake_timeout,omitempty"` +} + +type KeepAlive struct { + Enabled *bool `json:"enabled,omitempty"` + ProbeInterval caddy.Duration `json:"probe_interval,omitempty"` + MaxIdleConns int `json:"max_idle_conns,omitempty"` + MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"` + IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle +} + +var ( + defaultDialer = net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } + + // TODO: does this need to be configured to enable HTTP/2? + defaultTransport = &http.Transport{ + DialContext: defaultDialer.DialContext, + TLSHandshakeTimeout: 5 * time.Second, + IdleConnTimeout: 2 * time.Minute, + } +) diff --git a/modules/caddyhttp/reverseproxy/module.go b/modules/caddyhttp/reverseproxy/module.go deleted file mode 100755 index 21aca1d..0000000 --- a/modules/caddyhttp/reverseproxy/module.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2015 Matthew Holt and The Caddy Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package reverseproxy - -import ( - "github.com/caddyserver/caddy/v2" - "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" - "github.com/caddyserver/caddy/v2/modules/caddyhttp" -) - -func init() { - caddy.RegisterModule(new(LoadBalanced)) - httpcaddyfile.RegisterHandlerDirective("reverse_proxy", parseCaddyfile) // TODO: "proxy"? -} - -// CaddyModule returns the Caddy module information. -func (*LoadBalanced) CaddyModule() caddy.ModuleInfo { - return caddy.ModuleInfo{ - Name: "http.handlers.reverse_proxy", - New: func() caddy.Module { return new(LoadBalanced) }, - } -} - -// parseCaddyfile sets up the handler from Caddyfile tokens. Syntax: -// -// proxy [] -// -// TODO: This needs to be finished. It definitely needs to be able to open a block... -func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { - lb := new(LoadBalanced) - for h.Next() { - allTo := h.RemainingArgs() - if len(allTo) == 0 { - return nil, h.ArgErr() - } - for _, to := range allTo { - lb.Upstreams = append(lb.Upstreams, &UpstreamConfig{Host: to}) - } - } - return lb, nil -} diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go old mode 100755 new mode 100644 index 68393de..e312d71 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -15,7 +15,9 @@ package reverseproxy import ( + "bytes" "context" + "encoding/json" "fmt" "io" "log" @@ -24,219 +26,229 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "time" + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" "golang.org/x/net/http/httpguts" ) -// ReverseProxy is an HTTP Handler that takes an incoming request and -// sends it to another server, proxying the response back to the -// client. -type ReverseProxy struct { - // Director must be a function which modifies - // the request into a new request to be sent - // using Transport. Its response is then copied - // back to the original client unmodified. - // Director must not access the provided Request - // after returning. - Director func(*http.Request) - - // The transport used to perform proxy requests. - // If nil, http.DefaultTransport is used. - Transport http.RoundTripper - - // FlushInterval specifies the flush interval - // to flush to the client while copying the - // response body. - // If zero, no periodic flushing is done. - // A negative value means to flush immediately - // after each write to the client. - // The FlushInterval is ignored when ReverseProxy - // recognizes a response as a streaming response; - // for such responses, writes are flushed to the client - // immediately. - FlushInterval time.Duration - - // ErrorLog specifies an optional logger for errors - // that occur when attempting to proxy the request. - // If nil, logging goes to os.Stderr via the log package's - // standard logger. - ErrorLog *log.Logger - - // BufferPool optionally specifies a buffer pool to - // get byte slices for use by io.CopyBuffer when - // copying HTTP response bodies. - BufferPool BufferPool - - // ModifyResponse is an optional function that modifies the - // Response from the backend. It is called if the backend - // returns a response at all, with any HTTP status code. - // If the backend is unreachable, the optional ErrorHandler is - // called without any call to ModifyResponse. - // - // If ModifyResponse returns an error, ErrorHandler is called - // with its error value. If ErrorHandler is nil, its default - // implementation is used. - ModifyResponse func(*http.Response) error - - // ErrorHandler is an optional function that handles errors - // reaching the backend or errors from ModifyResponse. - // - // If nil, the default is to log the provided error and return - // a 502 Status Bad Gateway response. - ErrorHandler func(http.ResponseWriter, *http.Request, error) -} - -// A BufferPool is an interface for getting and returning temporary -// byte slices for use by io.CopyBuffer. -type BufferPool interface { - Get() []byte - Put([]byte) +func init() { + caddy.RegisterModule(Handler{}) } -func singleJoiningSlash(a, b string) string { - aslash := strings.HasSuffix(a, "/") - bslash := strings.HasPrefix(b, "/") - switch { - case aslash && bslash: - return a + b[1:] - case !aslash && !bslash: - return a + "/" + b +type Handler struct { + TransportRaw json.RawMessage `json:"transport,omitempty"` + LoadBalancing *LoadBalancing `json:"load_balancing,omitempty"` + HealthChecks *HealthChecks `json:"health_checks,omitempty"` + // UpstreamStorageRaw json.RawMessage `json:"upstream_storage,omitempty"` // TODO: + Upstreams HostPool `json:"upstreams,omitempty"` + + // UpstreamProvider UpstreamProvider `json:"-"` // TODO: + Transport http.RoundTripper `json:"-"` +} + +// CaddyModule returns the Caddy module information. +func (Handler) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy", + New: func() caddy.Module { return new(Handler) }, } - return a + b } -// NewSingleHostReverseProxy returns a new ReverseProxy that routes -// URLs to the scheme, host, and base path provided in target. If the -// target's path is "/base" and the incoming request was for "/dir", -// the target request will be for /base/dir. -// NewSingleHostReverseProxy does not rewrite the Host header. -// To rewrite Host headers, use ReverseProxy directly with a custom -// Director policy. -func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { - targetQuery := target.RawQuery - director := func(req *http.Request) { - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) - if targetQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = targetQuery + req.URL.RawQuery - } else { - req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery +func (h *Handler) Provision(ctx caddy.Context) error { + if h.TransportRaw != nil { + val, err := ctx.LoadModuleInline("protocol", "http.handlers.reverse_proxy.transport", h.TransportRaw) + if err != nil { + return fmt.Errorf("loading transport module: %s", err) } - if _, ok := req.Header["User-Agent"]; !ok { - // explicitly disable User-Agent so it's not set to default value - req.Header.Set("User-Agent", "") + h.Transport = val.(http.RoundTripper) + h.TransportRaw = nil // allow GC to deallocate - TODO: Does this help? + } + if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil { + val, err := ctx.LoadModuleInline("policy", + "http.handlers.reverse_proxy.selection_policies", + h.LoadBalancing.SelectionPolicyRaw) + if err != nil { + return fmt.Errorf("loading load balancing selection module: %s", err) } + h.LoadBalancing.SelectionPolicy = val.(Selector) + h.LoadBalancing.SelectionPolicyRaw = nil // allow GC to deallocate - TODO: Does this help? } - return &ReverseProxy{Director: director} -} -func copyHeader(dst, src http.Header) { - for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) - } + if h.Transport == nil { + h.Transport = defaultTransport } -} -func cloneHeader(h http.Header) http.Header { - h2 := make(http.Header, len(h)) - for k, vv := range h { - vv2 := make([]string, len(vv)) - copy(vv2, vv) - h2[k] = vv2 + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + if h.LoadBalancing.SelectionPolicy == nil { + h.LoadBalancing.SelectionPolicy = RandomSelection{} + } + if h.LoadBalancing.TryDuration > 0 && h.LoadBalancing.TryInterval == 0 { + // a non-zero try_duration with a zero try_interval + // will always spin the CPU for try_duration if the + // upstream is local or low-latency; default to some + // sane waiting period before try attempts + h.LoadBalancing.TryInterval = caddy.Duration(250 * time.Millisecond) } - return h2 -} -// Hop-by-hop headers. These are removed when sent to the backend. -// As of RFC 7230, hop-by-hop headers are required to appear in the -// Connection header field. These are the headers defined by the -// obsoleted RFC 2616 (section 13.5.1) and are used for backward -// compatibility. -var hopHeaders = []string{ - "Connection", - "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "Te", // canonicalized version of "TE" - "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 - "Transfer-Encoding", - "Upgrade", -} + for _, upstream := range h.Upstreams { + // url parser requires a scheme + if !strings.Contains(upstream.Address, "://") { + upstream.Address = "http://" + upstream.Address + } + u, err := url.Parse(upstream.Address) + if err != nil { + return fmt.Errorf("invalid upstream address %s: %v", upstream.Address, err) + } + upstream.hostURL = u + + // if host already exists from a current config, + // use that instead; otherwise, add it + // TODO: make hosts modular, so that their state can be distributed in enterprise for example + // TODO: If distributed, the pool should be stored in storage... + var host Host = new(upstreamHost) + activeHost, loaded := hosts.LoadOrStore(u.String(), host) + if loaded { + host = activeHost.(Host) + } + upstream.Host = host + + // if the passive health checker has a non-zero "unhealthy + // request count" but the upstream has no MaxRequests set + // (they are the same thing, but one is a default value for + // for upstreams with a zero MaxRequests), copy the default + // value into this upstream, since the value in the upstream + // is what is used during availability checks + if h.HealthChecks != nil && + h.HealthChecks.Passive != nil && + h.HealthChecks.Passive.UnhealthyRequestCount > 0 && + upstream.MaxRequests == 0 { + upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount + } -func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { - p.logf("http: proxy error: %v", err) - rw.WriteHeader(http.StatusBadGateway) -} + // TODO: active health checks -func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { - if p.ErrorHandler != nil { - return p.ErrorHandler + if h.HealthChecks != nil { + // upstreams need independent access to the passive + // health check policy so they can, you know, passively + // do health checks + upstream.healthCheckPolicy = h.HealthChecks.Passive + } } - return p.defaultErrorHandler + + return nil } -// modifyResponse conditionally runs the optional ModifyResponse hook -// and reports whether the request should proceed. -func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { - if p.ModifyResponse == nil { - return true - } - if err := p.ModifyResponse(res); err != nil { - res.Body.Close() - p.getErrorHandler()(rw, req, err) - return false +func (h *Handler) Cleanup() error { + // TODO: finish this up, make sure it takes care of any active health checkers or whatever + for _, upstream := range h.Upstreams { + hosts.Delete(upstream.hostURL.String()) } - return true + return nil } -func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*http.Response, error) { - transport := p.Transport - if transport == nil { - transport = http.DefaultTransport +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { + // prepare the request for proxying; this is needed only once + err := h.prepareRequest(r) + if err != nil { + return caddyhttp.Error(http.StatusInternalServerError, + fmt.Errorf("preparing request for upstream round-trip: %v", err)) } - ctx := req.Context() - if cn, ok := rw.(http.CloseNotifier); ok { - var cancel context.CancelFunc - ctx, cancel = context.WithCancel(ctx) - defer cancel() - notifyChan := cn.CloseNotify() - go func() { - select { - case <-notifyChan: - cancel() - case <-ctx.Done(): + start := time.Now() + + var proxyErr error + for { + // choose an available upstream + upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, r) + if upstream == nil { + if proxyErr == nil { + proxyErr = fmt.Errorf("no available upstreams") + } + if !h.tryAgain(start, proxyErr) { + break } - }() + continue + } + + // proxy the request to that upstream + proxyErr = h.reverseProxy(w, r, upstream) + if proxyErr == nil { + return nil + } + + // remember this failure (if enabled) + h.countFailure(upstream) + + // if we've tried long enough, break + if !h.tryAgain(start, proxyErr) { + break + } } - outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay + return caddyhttp.Error(http.StatusBadGateway, proxyErr) +} + +// prepareRequest modifies req so that it is ready to be proxied, +// except for directing to a specific upstream. This method mutates +// headers and other necessary properties of the request and should +// be done just once (before proxying) regardless of proxy retries. +// This assumes that no mutations of the request are performed +// by h during or after proxying. +func (h Handler) prepareRequest(req *http.Request) error { + // ctx := req.Context() + // TODO: do we need to support CloseNotifier? It was deprecated years ago. + // All this does is wrap CloseNotify with context cancel, for those responsewriters + // which didn't support context, but all the ones we'd use should nowadays, right? + // if cn, ok := rw.(http.CloseNotifier); ok { + // var cancel context.CancelFunc + // ctx, cancel = context.WithCancel(ctx) + // defer cancel() + // notifyChan := cn.CloseNotify() + // go func() { + // select { + // case <-notifyChan: + // cancel() + // case <-ctx.Done(): + // } + // }() + // } + + // TODO: do we need to call WithContext, since we won't be changing req.Context() above if we remove the CloseNotifier stuff? + // TODO: (This is where references to req were originally "outreq", a shallow clone, which I think is unnecessary in our case) + // req = req.WithContext(ctx) // includes shallow copies of maps, but okay if req.ContentLength == 0 { - outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries } - outreq.Header = cloneHeader(req.Header) + // TODO: is this needed? + // req.Header = cloneHeader(req.Header) - p.Director(outreq) - outreq.Close = false + req.Close = false - reqUpType := upgradeType(outreq.Header) - removeConnectionHeaders(outreq.Header) + // if User-Agent is not set by client, then explicitly + // disable it so it's not set to default value by std lib + if _, ok := req.Header["User-Agent"]; !ok { + req.Header.Set("User-Agent", "") + } + + reqUpType := upgradeType(req.Header) + removeConnectionHeaders(req.Header) // Remove hop-by-hop headers to the backend. Especially // important is "Connection" because we want a persistent // connection, regardless of what the client sent to us. for _, h := range hopHeaders { - hv := outreq.Header.Get(h) + hv := req.Header.Get(h) if hv == "" { continue } if h == "Te" && hv == "trailers" { - // Issue 21096: tell backend applications that + // Issue golang/go#21096: tell backend applications that // care about trailer support that we support // trailers. (We do, but we don't go out of // our way to advertise that unless the @@ -244,40 +256,65 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht // worth mentioning) continue } - outreq.Header.Del(h) + req.Header.Del(h) } // After stripping all the hop-by-hop connection headers above, add back any // necessary for protocol upgrades, such as for websockets. if reqUpType != "" { - outreq.Header.Set("Connection", "Upgrade") - outreq.Header.Set("Upgrade", reqUpType) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", reqUpType) } if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { // If we aren't the first proxy retain prior // X-Forwarded-For information as a comma+space // separated list and fold multiple headers into one. - if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + if prior, ok := req.Header["X-Forwarded-For"]; ok { clientIP = strings.Join(prior, ", ") + ", " + clientIP } - outreq.Header.Set("X-Forwarded-For", clientIP) + req.Header.Set("X-Forwarded-For", clientIP) } - res, err := transport.RoundTrip(outreq) + return nil +} + +// TODO: +// this code is the entry point to what was borrowed from the net/http/httputil package in the standard library. +func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, upstream *Upstream) error { + // TODO: count this active request + + // point the request to this upstream + h.directRequest(req, upstream) + + // do the round-trip + start := time.Now() + res, err := h.Transport.RoundTrip(req) + latency := time.Since(start) if err != nil { - p.getErrorHandler()(rw, outreq, err) - return nil, err + return err } - // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) - if res.StatusCode == http.StatusSwitchingProtocols { - if !p.modifyResponse(rw, res, outreq) { - return res, nil + // perform passive health checks (if enabled) + if h.HealthChecks != nil && h.HealthChecks.Passive != nil { + // strike if the status code matches one that is "bad" + for _, badStatus := range h.HealthChecks.Passive.UnhealthyStatus { + if caddyhttp.StatusCodeMatches(res.StatusCode, badStatus) { + h.countFailure(upstream) + } } - p.handleUpgradeResponse(rw, outreq, res) - return res, nil + // strike if the roundtrip took too long + if h.HealthChecks.Passive.UnhealthyLatency > 0 && + latency >= time.Duration(h.HealthChecks.Passive.UnhealthyLatency) { + h.countFailure(upstream) + } + } + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + h.handleUpgradeResponse(rw, req, res) + return nil } removeConnectionHeaders(res.Header) @@ -286,10 +323,6 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht res.Header.Del(h) } - if !p.modifyResponse(rw, res, outreq) { - return res, nil - } - copyHeader(rw.Header(), res.Header) // The "Trailer" header isn't included in the Transport's response, @@ -305,15 +338,16 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht rw.WriteHeader(res.StatusCode) - err = p.copyResponse(rw, res.Body, p.flushInterval(req, res)) + err = h.copyResponse(rw, res.Body, h.flushInterval(req, res)) if err != nil { defer res.Body.Close() // Since we're streaming the response, if we run into an error all we can do - // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler + // is abort the request. Issue golang/go#23643: ReverseProxy should use ErrAbortHandler // on read error while copying body. + // TODO: Look into whether we want to panic at all in our case... if !shouldPanicOnCopyError(req) { - p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) - return nil, err + // p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) + return err } panic(http.ErrAbortHandler) @@ -331,7 +365,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht if len(res.Trailer) == announcedTrailers { copyHeader(rw.Header(), res.Trailer) - return res, nil + return nil } for k, vv := range res.Trailer { @@ -341,46 +375,90 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*ht } } - return res, nil + return nil } -var inOurTests bool // whether we're in our own tests - -// shouldPanicOnCopyError reports whether the reverse proxy should -// panic with http.ErrAbortHandler. This is the right thing to do by -// default, but Go 1.10 and earlier did not, so existing unit tests -// weren't expecting panics. Only panic in our own tests, or when -// running under the HTTP server. -func shouldPanicOnCopyError(req *http.Request) bool { - if inOurTests { - // Our tests know to handle this panic. - return true +// tryAgain takes the time that the handler was initially invoked +// as well as any error currently obtained and returns true if +// another attempt should be made at proxying the request. If +// true is returned, it has already blocked long enough before +// the next retry (i.e. no more sleeping is needed). If false is +// returned, the handler should stop trying to proxy the request. +func (h Handler) tryAgain(start time.Time, proxyErr error) bool { + // if downstream has canceled the request, break + if proxyErr == context.Canceled { + return false } - if req.Context().Value(http.ServerContextKey) != nil { - // We seem to be running under an HTTP server, so - // it'll recover the panic. - return true + // if we've tried long enough, break + if time.Since(start) >= time.Duration(h.LoadBalancing.TryDuration) { + return false } - // Otherwise act like Go 1.10 and earlier to not break - // existing tests. - return false + // otherwise, wait and try the next available host + time.Sleep(time.Duration(h.LoadBalancing.TryInterval)) + return true } -// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. -// See RFC 7230, section 6.1 -func removeConnectionHeaders(h http.Header) { - if c := h.Get("Connection"); c != "" { - for _, f := range strings.Split(c, ",") { - if f = strings.TrimSpace(f); f != "" { - h.Del(f) - } - } +// directRequest modifies only req.URL so that it points to the +// given upstream host. It must modify ONLY the request URL. +func (h Handler) directRequest(req *http.Request, upstream *Upstream) { + target := upstream.hostURL + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) // TODO: This might be a bug (if any part of the path was augmented from a previously-tried upstream; need to start from clean original path of request, same for query string!) + if target.RawQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = target.RawQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = target.RawQuery + "&" + req.URL.RawQuery } } +func (h Handler) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if reqUpType != resUpType { + // p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + + copyHeader(res.Header, rw.Header()) + + hj, ok := rw.(http.Hijacker) + if !ok { + // p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + // p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + defer backConn.Close() + conn, brw, err := hj.Hijack() + if err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + return + } + defer conn.Close() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc + return +} + // flushInterval returns the p.FlushInterval value, conditionally // overriding its value for a specific request/response. -func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration { +func (h Handler) flushInterval(req *http.Request, res *http.Response) time.Duration { resCT := res.Header.Get("Content-Type") // For Server-Sent Events responses, flush immediately. @@ -390,10 +468,11 @@ func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time } // TODO: more specific cases? e.g. res.ContentLength == -1? - return p.FlushInterval + // return h.FlushInterval + return 0 } -func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { +func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { if flushInterval != 0 { if wf, ok := dst.(writeFlusher); ok { mlw := &maxLatencyWriter{ @@ -405,18 +484,25 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval } } + // TODO: Figure out how we want to do this... using custom buffer pool type seems unnecessary + // or maybe it is, depending on how we want to handle errors, + // see: https://github.com/golang/go/issues/21814 + // buf := bufPool.Get().(*bytes.Buffer) + // buf.Reset() + // defer bufPool.Put(buf) + // _, err := io.CopyBuffer(dst, src, ) var buf []byte - if p.BufferPool != nil { - buf = p.BufferPool.Get() - defer p.BufferPool.Put(buf) - } - _, err := p.copyBuffer(dst, src, buf) + // if h.BufferPool != nil { + // buf = h.BufferPool.Get() + // defer h.BufferPool.Put(buf) + // } + _, err := h.copyBuffer(dst, src, buf) return err } // copyBuffer returns any write errors or non-EOF read errors, and the amount // of bytes written. -func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { +func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { if len(buf) == 0 { buf = make([]byte, 32*1024) } @@ -424,7 +510,13 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int for { nr, rerr := src.Read(buf) if rerr != nil && rerr != io.EOF && rerr != context.Canceled { - p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + // TODO: this could be useful to know (indeed, it revealed an error in our + // fastcgi PoC earlier; but it's this single error report here that necessitates + // a function separate from io.CopyBuffer, since io.CopyBuffer does not distinguish + // between read or write errors; in a reverse proxy situation, write errors are not + // something we need to report to the client, but read errors are a problem on our + // end for sure. so we need to decide what we want.) + // p.logf("copyBuffer: ReverseProxy read error during body copy: %v", rerr) } if nr > 0 { nw, werr := dst.Write(buf[:nr]) @@ -447,12 +539,36 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int } } -func (p *ReverseProxy) logf(format string, args ...interface{}) { - if p.ErrorLog != nil { - p.ErrorLog.Printf(format, args...) - } else { - log.Printf(format, args...) +// countFailure remembers 1 failure for upstream for the +// configured duration. If passive health checks are +// disabled or failure expiry is 0, this is a no-op. +func (h Handler) countFailure(upstream *Upstream) { + // only count failures if passive health checking is enabled + // and if failures are configured have a non-zero expiry + if h.HealthChecks == nil || h.HealthChecks.Passive == nil { + return + } + failDuration := time.Duration(h.HealthChecks.Passive.FailDuration) + if failDuration == 0 { + return + } + + // count failure immediately + err := upstream.Host.CountFail(1) + if err != nil { + log.Printf("[ERROR] proxy: upstream %s: counting failure: %v", + upstream.hostURL, err) } + + // forget it later + go func(host Host, failDuration time.Duration) { + time.Sleep(failDuration) + err := host.CountFail(-1) + if err != nil { + log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v", + upstream.hostURL, err) + } + }(upstream.Host, failDuration) } type writeFlusher interface { @@ -508,6 +624,61 @@ func (m *maxLatencyWriter) stop() { } } +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} + +// shouldPanicOnCopyError reports whether the reverse proxy should +// panic with http.ErrAbortHandler. This is the right thing to do by +// default, but Go 1.10 and earlier did not, so existing unit tests +// weren't expecting panics. Only panic in our own tests, or when +// running under the HTTP server. +// TODO: I don't know if we want this at all... +func shouldPanicOnCopyError(req *http.Request) bool { + // if inOurTests { + // // Our tests know to handle this panic. + // return true + // } + if req.Context().Value(http.ServerContextKey) != nil { + // We seem to be running under an HTTP server, so + // it'll recover the panic. + return true + } + // Otherwise act like Go 1.10 and earlier to not break + // existing tests. + return false +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + func upgradeType(h http.Header) string { if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { return "" @@ -515,62 +686,177 @@ func upgradeType(h http.Header) string { return strings.ToLower(h.Get("Upgrade")) } -func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { - reqUpType := upgradeType(req.Header) - resUpType := upgradeType(res.Header) - if reqUpType != resUpType { - p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) - return +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b } + return a + b +} - copyHeader(res.Header, rw.Header()) - - hj, ok := rw.(http.Hijacker) - if !ok { - p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) - return - } - backConn, ok := res.Body.(io.ReadWriteCloser) - if !ok { - p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) - return +// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. +// See RFC 7230, section 6.1 +func removeConnectionHeaders(h http.Header) { + if c := h.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + h.Del(f) + } + } } - defer backConn.Close() - conn, brw, err := hj.Hijack() - if err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) - return +} + +type LoadBalancing struct { + SelectionPolicyRaw json.RawMessage `json:"selection_policy,omitempty"` + TryDuration caddy.Duration `json:"try_duration,omitempty"` + TryInterval caddy.Duration `json:"try_interval,omitempty"` + + SelectionPolicy Selector `json:"-"` +} + +type Selector interface { + Select(HostPool, *http.Request) *Upstream +} + +type HealthChecks struct { + Active *ActiveHealthChecks `json:"active,omitempty"` + Passive *PassiveHealthChecks `json:"passive,omitempty"` +} + +type ActiveHealthChecks struct { + Path string `json:"path,omitempty"` + Port int `json:"port,omitempty"` + Interval caddy.Duration `json:"interval,omitempty"` + Timeout caddy.Duration `json:"timeout,omitempty"` + MaxSize int `json:"max_size,omitempty"` + ExpectStatus int `json:"expect_status,omitempty"` + ExpectBody string `json:"expect_body,omitempty"` +} + +type PassiveHealthChecks struct { + MaxFails int `json:"max_fails,omitempty"` + FailDuration caddy.Duration `json:"fail_duration,omitempty"` + UnhealthyRequestCount int `json:"unhealthy_request_count,omitempty"` + UnhealthyStatus []int `json:"unhealthy_status,omitempty"` + UnhealthyLatency caddy.Duration `json:"unhealthy_latency,omitempty"` +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + +////////////////////////////////// +// TODO: + +type Host interface { + NumRequests() int + Fails() int + Unhealthy() bool + + CountRequest(int) error + CountFail(int) error +} + +type HostPool []*Upstream + +type upstreamHost struct { + numRequests int64 // must be first field to be 64-bit aligned on 32-bit systems (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG) + fails int64 + unhealthy int32 +} + +func (uh upstreamHost) NumRequests() int { + return int(atomic.LoadInt64(&uh.numRequests)) +} +func (uh upstreamHost) Fails() int { + return int(atomic.LoadInt64(&uh.fails)) +} +func (uh upstreamHost) Unhealthy() bool { + return atomic.LoadInt32(&uh.unhealthy) == 1 +} +func (uh *upstreamHost) CountRequest(delta int) error { + result := atomic.AddInt64(&uh.numRequests, int64(delta)) + if result < 0 { + return fmt.Errorf("count below 0: %d", result) } - defer conn.Close() - res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above - if err := res.Write(brw); err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) - return + return nil +} +func (uh *upstreamHost) CountFail(delta int) error { + result := atomic.AddInt64(&uh.fails, int64(delta)) + if result < 0 { + return fmt.Errorf("count below 0: %d", result) } - if err := brw.Flush(); err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) - return + return nil +} + +type Upstream struct { + Host `json:"-"` + + Address string `json:"address,omitempty"` + MaxRequests int `json:"max_requests,omitempty"` + + // TODO: This could be really cool, to say that requests with + // certain headers or from certain IPs always go to this upstream + // HeaderAffinity string + // IPAffinity string + + healthCheckPolicy *PassiveHealthChecks + + hostURL *url.URL +} + +func (u Upstream) Available() bool { + return u.Healthy() && !u.Full() +} + +func (u Upstream) Healthy() bool { + healthy := !u.Host.Unhealthy() + if healthy && u.healthCheckPolicy != nil { + healthy = u.Host.Fails() < u.healthCheckPolicy.MaxFails } - errc := make(chan error, 1) - spc := switchProtocolCopier{user: conn, backend: backConn} - go spc.copyToBackend(errc) - go spc.copyFromBackend(errc) - <-errc - return + return healthy } -// switchProtocolCopier exists so goroutines proxying data back and -// forth have nice names in stacks. -type switchProtocolCopier struct { - user, backend io.ReadWriter +func (u Upstream) Full() bool { + return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests } -func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { - _, err := io.Copy(c.user, c.backend) - errc <- err +func (u Upstream) URL() *url.URL { + return u.hostURL } -func (c switchProtocolCopier) copyToBackend(errc chan<- error) { - _, err := io.Copy(c.backend, c.user) - errc <- err +var hosts = caddy.NewUsagePool() + +// TODO: ... +type UpstreamProvider interface { } + +// Interface guards +var ( + _ caddyhttp.MiddlewareHandler = (*Handler)(nil) + _ caddy.Provisioner = (*Handler)(nil) + _ caddy.CleanerUpper = (*Handler)(nil) +) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go new file mode 100644 index 0000000..e0518c9 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -0,0 +1,351 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "fmt" + "hash/fnv" + weakrand "math/rand" + "net" + "net/http" + "sync/atomic" + "time" + + "github.com/caddyserver/caddy/v2" +) + +func init() { + caddy.RegisterModule(RandomSelection{}) + caddy.RegisterModule(RandomChoiceSelection{}) + caddy.RegisterModule(LeastConnSelection{}) + caddy.RegisterModule(RoundRobinSelection{}) + caddy.RegisterModule(FirstSelection{}) + caddy.RegisterModule(IPHashSelection{}) + caddy.RegisterModule(URIHashSelection{}) + caddy.RegisterModule(HeaderHashSelection{}) + + weakrand.Seed(time.Now().UTC().UnixNano()) +} + +// RandomSelection is a policy that selects +// an available host at random. +type RandomSelection struct{} + +// CaddyModule returns the Caddy module information. +func (RandomSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.random", + New: func() caddy.Module { return new(RandomSelection) }, + } +} + +// Select returns an available host, if any. +func (r RandomSelection) Select(pool HostPool, request *http.Request) *Upstream { + // use reservoir sampling because the number of available + // hosts isn't known: https://en.wikipedia.org/wiki/Reservoir_sampling + var randomHost *Upstream + var count int + for _, upstream := range pool { + if !upstream.Available() { + continue + } + // (n % 1 == 0) holds for all n, therefore a + // upstream will always be chosen if there is at + // least one available + count++ + if (weakrand.Int() % count) == 0 { + randomHost = upstream + } + } + return randomHost +} + +// RandomChoiceSelection is a policy that selects +// two or more available hosts at random, then +// chooses the one with the least load. +type RandomChoiceSelection struct { + Choose int `json:"choose,omitempty"` +} + +// CaddyModule returns the Caddy module information. +func (RandomChoiceSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.random_choice", + New: func() caddy.Module { return new(RandomChoiceSelection) }, + } +} + +func (r *RandomChoiceSelection) Provision(ctx caddy.Context) error { + if r.Choose == 0 { + r.Choose = 2 + } + return nil +} + +func (r RandomChoiceSelection) Validate() error { + if r.Choose < 2 { + return fmt.Errorf("choose must be at least 2") + } + return nil +} + +// Select returns an available host, if any. +func (r RandomChoiceSelection) Select(pool HostPool, _ *http.Request) *Upstream { + k := r.Choose + if k > len(pool) { + k = len(pool) + } + choices := make([]*Upstream, k) + for i, upstream := range pool { + if !upstream.Available() { + continue + } + j := weakrand.Intn(i) + if j < k { + choices[j] = upstream + } + } + return leastRequests(choices) +} + +// LeastConnSelection is a policy that selects the +// host with the least active requests. If multiple +// hosts have the same fewest number, one is chosen +// randomly. The term "conn" or "connection" is used +// in this policy name due to its similar meaning in +// other software, but our load balancer actually +// counts active requests rather than connections, +// since these days requests are multiplexed onto +// shared connections. +type LeastConnSelection struct{} + +// CaddyModule returns the Caddy module information. +func (LeastConnSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.least_conn", + New: func() caddy.Module { return new(LeastConnSelection) }, + } +} + +// Select selects the up host with the least number of connections in the +// pool. If more than one host has the same least number of connections, +// one of the hosts is chosen at random. +func (LeastConnSelection) Select(pool HostPool, _ *http.Request) *Upstream { + var bestHost *Upstream + var count int + var leastReqs int + + for _, host := range pool { + if !host.Available() { + continue + } + numReqs := host.NumRequests() + if numReqs < leastReqs { + leastReqs = numReqs + count = 0 + } + + // among hosts with same least connections, perform a reservoir + // sample: https://en.wikipedia.org/wiki/Reservoir_sampling + if numReqs == leastReqs { + count++ + if (weakrand.Int() % count) == 0 { + bestHost = host + } + } + } + + return bestHost +} + +// RoundRobinSelection is a policy that selects +// a host based on round-robin ordering. +type RoundRobinSelection struct { + robin uint32 +} + +// CaddyModule returns the Caddy module information. +func (RoundRobinSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.round_robin", + New: func() caddy.Module { return new(RoundRobinSelection) }, + } +} + +// Select returns an available host, if any. +func (r *RoundRobinSelection) Select(pool HostPool, _ *http.Request) *Upstream { + n := uint32(len(pool)) + if n == 0 { + return nil + } + for i := uint32(0); i < n; i++ { + atomic.AddUint32(&r.robin, 1) + host := pool[r.robin%n] + if host.Available() { + return host + } + } + return nil +} + +// FirstSelection is a policy that selects +// the first available host. +type FirstSelection struct{} + +// CaddyModule returns the Caddy module information. +func (FirstSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.first", + New: func() caddy.Module { return new(FirstSelection) }, + } +} + +// Select returns an available host, if any. +func (FirstSelection) Select(pool HostPool, _ *http.Request) *Upstream { + for _, host := range pool { + if host.Available() { + return host + } + } + return nil +} + +// IPHashSelection is a policy that selects a host +// based on hashing the remote IP of the request. +type IPHashSelection struct{} + +// CaddyModule returns the Caddy module information. +func (IPHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.ip_hash", + New: func() caddy.Module { return new(IPHashSelection) }, + } +} + +// Select returns an available host, if any. +func (IPHashSelection) Select(pool HostPool, req *http.Request) *Upstream { + clientIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + clientIP = req.RemoteAddr + } + return hostByHashing(pool, clientIP) +} + +// URIHashSelection is a policy that selects a +// host by hashing the request URI. +type URIHashSelection struct{} + +// CaddyModule returns the Caddy module information. +func (URIHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.uri_hash", + New: func() caddy.Module { return new(URIHashSelection) }, + } +} + +// Select returns an available host, if any. +func (URIHashSelection) Select(pool HostPool, req *http.Request) *Upstream { + return hostByHashing(pool, req.RequestURI) +} + +// HeaderHashSelection is a policy that selects +// a host based on a given request header. +type HeaderHashSelection struct { + Field string `json:"field,omitempty"` +} + +// CaddyModule returns the Caddy module information. +func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + Name: "http.handlers.reverse_proxy.selection_policies.header", + New: func() caddy.Module { return new(HeaderHashSelection) }, + } +} + +// Select returns an available host, if any. +func (s HeaderHashSelection) Select(pool HostPool, req *http.Request) *Upstream { + if s.Field == "" { + return nil + } + val := req.Header.Get(s.Field) + if val == "" { + return RandomSelection{}.Select(pool, req) + } + return hostByHashing(pool, val) +} + +// leastRequests returns the host with the +// least number of active requests to it. +// If more than one host has the same +// least number of active requests, then +// one of those is chosen at random. +func leastRequests(upstreams []*Upstream) *Upstream { + if len(upstreams) == 0 { + return nil + } + var best []*Upstream + var bestReqs int + for _, upstream := range upstreams { + reqs := upstream.NumRequests() + if reqs == 0 { + return upstream + } + if reqs <= bestReqs { + bestReqs = reqs + best = append(best, upstream) + } + } + return best[weakrand.Intn(len(best))] +} + +// hostByHashing returns an available host +// from pool based on a hashable string s. +func hostByHashing(pool []*Upstream, s string) *Upstream { + poolLen := uint32(len(pool)) + if poolLen == 0 { + return nil + } + index := hash(s) % poolLen + for i := uint32(0); i < poolLen; i++ { + index += i + upstream := pool[index%poolLen] + if upstream.Available() { + return upstream + } + } + return nil +} + +// hash calculates a fast hash based on s. +func hash(s string) uint32 { + h := fnv.New32a() + h.Write([]byte(s)) + return h.Sum32() +} + +// Interface guards +var ( + _ Selector = (*RandomSelection)(nil) + _ Selector = (*RandomChoiceSelection)(nil) + _ Selector = (*LeastConnSelection)(nil) + _ Selector = (*RoundRobinSelection)(nil) + _ Selector = (*FirstSelection)(nil) + _ Selector = (*IPHashSelection)(nil) + _ Selector = (*URIHashSelection)(nil) + _ Selector = (*HeaderHashSelection)(nil) + + _ caddy.Validator = (*RandomChoiceSelection)(nil) + _ caddy.Provisioner = (*RandomChoiceSelection)(nil) +) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go new file mode 100644 index 0000000..8006fb1 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go @@ -0,0 +1,363 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +// TODO: finish migrating these + +// import ( +// "net/http" +// "net/http/httptest" +// "os" +// "testing" +// ) + +// var workableServer *httptest.Server + +// func TestMain(m *testing.M) { +// workableServer = httptest.NewServer(http.HandlerFunc( +// func(w http.ResponseWriter, r *http.Request) { +// // do nothing +// })) +// r := m.Run() +// workableServer.Close() +// os.Exit(r) +// } + +// type customPolicy struct{} + +// func (customPolicy) Select(pool HostPool, _ *http.Request) Host { +// return pool[0] +// } + +// func testPool() HostPool { +// pool := []*UpstreamHost{ +// { +// Name: workableServer.URL, // this should resolve (healthcheck test) +// }, +// { +// Name: "http://localhost:99998", // this shouldn't +// }, +// { +// Name: "http://C", +// }, +// } +// return HostPool(pool) +// } + +// func TestRoundRobinPolicy(t *testing.T) { +// pool := testPool() +// rrPolicy := &RoundRobin{} +// request, _ := http.NewRequest("GET", "/", nil) + +// h := rrPolicy.Select(pool, request) +// // First selected host is 1, because counter starts at 0 +// // and increments before host is selected +// if h != pool[1] { +// t.Error("Expected first round robin host to be second host in the pool.") +// } +// h = rrPolicy.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected second round robin host to be third host in the pool.") +// } +// h = rrPolicy.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected third round robin host to be first host in the pool.") +// } +// // mark host as down +// pool[1].Unhealthy = 1 +// h = rrPolicy.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected to skip down host.") +// } +// // mark host as up +// pool[1].Unhealthy = 0 + +// h = rrPolicy.Select(pool, request) +// if h == pool[2] { +// t.Error("Expected to balance evenly among healthy hosts") +// } +// // mark host as full +// pool[1].Conns = 1 +// pool[1].MaxConns = 1 +// h = rrPolicy.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected to skip full host.") +// } +// } + +// func TestLeastConnPolicy(t *testing.T) { +// pool := testPool() +// lcPolicy := &LeastConn{} +// request, _ := http.NewRequest("GET", "/", nil) + +// pool[0].Conns = 10 +// pool[1].Conns = 10 +// h := lcPolicy.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected least connection host to be third host.") +// } +// pool[2].Conns = 100 +// h = lcPolicy.Select(pool, request) +// if h != pool[0] && h != pool[1] { +// t.Error("Expected least connection host to be first or second host.") +// } +// } + +// func TestCustomPolicy(t *testing.T) { +// pool := testPool() +// customPolicy := &customPolicy{} +// request, _ := http.NewRequest("GET", "/", nil) + +// h := customPolicy.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected custom policy host to be the first host.") +// } +// } + +// func TestIPHashPolicy(t *testing.T) { +// pool := testPool() +// ipHash := &IPHash{} +// request, _ := http.NewRequest("GET", "/", nil) +// // We should be able to predict where every request is routed. +// request.RemoteAddr = "172.0.0.1:80" +// h := ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } +// request.RemoteAddr = "172.0.0.2:80" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } +// request.RemoteAddr = "172.0.0.3:80" +// h = ipHash.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected ip hash policy host to be the third host.") +// } +// request.RemoteAddr = "172.0.0.4:80" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } + +// // we should get the same results without a port +// request.RemoteAddr = "172.0.0.1" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } +// request.RemoteAddr = "172.0.0.2" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } +// request.RemoteAddr = "172.0.0.3" +// h = ipHash.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected ip hash policy host to be the third host.") +// } +// request.RemoteAddr = "172.0.0.4" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } + +// // we should get a healthy host if the original host is unhealthy and a +// // healthy host is available +// request.RemoteAddr = "172.0.0.1" +// pool[1].Unhealthy = 1 +// h = ipHash.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected ip hash policy host to be the third host.") +// } + +// request.RemoteAddr = "172.0.0.2" +// h = ipHash.Select(pool, request) +// if h != pool[2] { +// t.Error("Expected ip hash policy host to be the third host.") +// } +// pool[1].Unhealthy = 0 + +// request.RemoteAddr = "172.0.0.3" +// pool[2].Unhealthy = 1 +// h = ipHash.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected ip hash policy host to be the first host.") +// } +// request.RemoteAddr = "172.0.0.4" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } + +// // We should be able to resize the host pool and still be able to predict +// // where a request will be routed with the same IP's used above +// pool = []*UpstreamHost{ +// { +// Name: workableServer.URL, // this should resolve (healthcheck test) +// }, +// { +// Name: "http://localhost:99998", // this shouldn't +// }, +// } +// pool = HostPool(pool) +// request.RemoteAddr = "172.0.0.1:80" +// h = ipHash.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected ip hash policy host to be the first host.") +// } +// request.RemoteAddr = "172.0.0.2:80" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } +// request.RemoteAddr = "172.0.0.3:80" +// h = ipHash.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected ip hash policy host to be the first host.") +// } +// request.RemoteAddr = "172.0.0.4:80" +// h = ipHash.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected ip hash policy host to be the second host.") +// } + +// // We should get nil when there are no healthy hosts +// pool[0].Unhealthy = 1 +// pool[1].Unhealthy = 1 +// h = ipHash.Select(pool, request) +// if h != nil { +// t.Error("Expected ip hash policy host to be nil.") +// } +// } + +// func TestFirstPolicy(t *testing.T) { +// pool := testPool() +// firstPolicy := &First{} +// req := httptest.NewRequest(http.MethodGet, "/", nil) + +// h := firstPolicy.Select(pool, req) +// if h != pool[0] { +// t.Error("Expected first policy host to be the first host.") +// } + +// pool[0].Unhealthy = 1 +// h = firstPolicy.Select(pool, req) +// if h != pool[1] { +// t.Error("Expected first policy host to be the second host.") +// } +// } + +// func TestUriPolicy(t *testing.T) { +// pool := testPool() +// uriPolicy := &URIHash{} + +// request := httptest.NewRequest(http.MethodGet, "/test", nil) +// h := uriPolicy.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected uri policy host to be the first host.") +// } + +// pool[0].Unhealthy = 1 +// h = uriPolicy.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected uri policy host to be the first host.") +// } + +// request = httptest.NewRequest(http.MethodGet, "/test_2", nil) +// h = uriPolicy.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected uri policy host to be the second host.") +// } + +// // We should be able to resize the host pool and still be able to predict +// // where a request will be routed with the same URI's used above +// pool = []*UpstreamHost{ +// { +// Name: workableServer.URL, // this should resolve (healthcheck test) +// }, +// { +// Name: "http://localhost:99998", // this shouldn't +// }, +// } + +// request = httptest.NewRequest(http.MethodGet, "/test", nil) +// h = uriPolicy.Select(pool, request) +// if h != pool[0] { +// t.Error("Expected uri policy host to be the first host.") +// } + +// pool[0].Unhealthy = 1 +// h = uriPolicy.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected uri policy host to be the first host.") +// } + +// request = httptest.NewRequest(http.MethodGet, "/test_2", nil) +// h = uriPolicy.Select(pool, request) +// if h != pool[1] { +// t.Error("Expected uri policy host to be the second host.") +// } + +// pool[0].Unhealthy = 1 +// pool[1].Unhealthy = 1 +// h = uriPolicy.Select(pool, request) +// if h != nil { +// t.Error("Expected uri policy policy host to be nil.") +// } +// } + +// func TestHeaderPolicy(t *testing.T) { +// pool := testPool() +// tests := []struct { +// Name string +// Policy *Header +// RequestHeaderName string +// RequestHeaderValue string +// NilHost bool +// HostIndex int +// }{ +// {"empty config", &Header{""}, "", "", true, 0}, +// {"empty config+header+value", &Header{""}, "Affinity", "somevalue", true, 0}, +// {"empty config+header", &Header{""}, "Affinity", "", true, 0}, + +// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 1}, +// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 2}, +// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 0}, + +// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue", false, 1}, +// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue2", false, 0}, +// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue3", false, 2}, +// {"hash route with empty value", &Header{"Affinity"}, "Affinity", "", false, 1}, +// } + +// for idx, test := range tests { +// request, _ := http.NewRequest("GET", "/", nil) +// if test.RequestHeaderName != "" { +// request.Header.Add(test.RequestHeaderName, test.RequestHeaderValue) +// } + +// host := test.Policy.Select(pool, request) +// if test.NilHost && host != nil { +// t.Errorf("%d: Expected host to be nil", idx) +// } +// if !test.NilHost && host == nil { +// t.Errorf("%d: Did not expect host to be nil", idx) +// } +// if !test.NilHost && host != pool[test.HostIndex] { +// t.Errorf("%d: Expected Header policy to be host %d", idx, test.HostIndex) +// } +// } +// } diff --git a/modules/caddyhttp/reverseproxy/upstream.go b/modules/caddyhttp/reverseproxy/upstream.go deleted file mode 100755 index 1f0693e..0000000 --- a/modules/caddyhttp/reverseproxy/upstream.go +++ /dev/null @@ -1,450 +0,0 @@ -// Copyright 2015 Matthew Holt and The Caddy Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package reverseproxy implements a load-balanced reverse proxy. -package reverseproxy - -import ( - "context" - "encoding/json" - "fmt" - "math/rand" - "net" - "net/http" - "net/url" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/caddyserver/caddy/v2" - "github.com/caddyserver/caddy/v2/modules/caddyhttp" -) - -// CircuitBreaker defines the functionality of a circuit breaker module. -type CircuitBreaker interface { - Ok() bool - RecordMetric(statusCode int, latency time.Duration) -} - -type noopCircuitBreaker struct{} - -func (ncb noopCircuitBreaker) RecordMetric(statusCode int, latency time.Duration) {} -func (ncb noopCircuitBreaker) Ok() bool { - return true -} - -const ( - // TypeBalanceRoundRobin represents the value to use for configuring a load balanced reverse proxy to use round robin load balancing. - TypeBalanceRoundRobin = iota - - // TypeBalanceRandom represents the value to use for configuring a load balanced reverse proxy to use random load balancing. - TypeBalanceRandom - - // TODO: add random with two choices - - // msgNoHealthyUpstreams is returned if there are no upstreams that are healthy to proxy a request to - msgNoHealthyUpstreams = "No healthy upstreams." - - // by default perform health checks every 30 seconds - defaultHealthCheckDur = time.Second * 30 - - // used when an upstream is unhealthy, health checks can be configured to perform at a faster rate - defaultFastHealthCheckDur = time.Second * 1 -) - -var ( - // defaultTransport is the default transport to use for the reverse proxy. - defaultTransport = &http.Transport{ - Dial: (&net.Dialer{ - Timeout: 5 * time.Second, - }).Dial, - TLSHandshakeTimeout: 5 * time.Second, - } - - // defaultHTTPClient is the default http client to use for the healthchecker. - defaultHTTPClient = &http.Client{ - Timeout: time.Second * 10, - Transport: defaultTransport, - } - - // typeMap maps caddy load balance configuration to the internal representation of the loadbalance algorithm type. - typeMap = map[string]int{ - "round_robin": TypeBalanceRoundRobin, - "random": TypeBalanceRandom, - } -) - -// NewLoadBalancedReverseProxy returns a collection of Upstreams that are to be loadbalanced. -func NewLoadBalancedReverseProxy(lb *LoadBalanced, ctx caddy.Context) error { - // set defaults - if lb.NoHealthyUpstreamsMessage == "" { - lb.NoHealthyUpstreamsMessage = msgNoHealthyUpstreams - } - - if lb.TryInterval == "" { - lb.TryInterval = "20s" - } - - // set request retry interval - ti, err := time.ParseDuration(lb.TryInterval) - if err != nil { - return fmt.Errorf("NewLoadBalancedReverseProxy: %v", err.Error()) - } - lb.tryInterval = ti - - // set load balance algorithm - t, ok := typeMap[lb.LoadBalanceType] - if !ok { - t = TypeBalanceRandom - } - lb.loadBalanceType = t - - // setup each upstream - var us []*upstream - for _, uc := range lb.Upstreams { - // pass the upstream decr and incr methods to keep track of unhealthy nodes - nu, err := newUpstream(uc, lb.decrUnhealthy, lb.incrUnhealthy) - if err != nil { - return err - } - - // setup any configured circuit breakers - var cbModule = "http.handlers.reverse_proxy.circuit_breaker" - var cb CircuitBreaker - - if uc.CircuitBreaker != nil { - if _, err := caddy.GetModule(cbModule); err == nil { - val, err := ctx.LoadModule(cbModule, uc.CircuitBreaker) - if err == nil { - cbv, ok := val.(CircuitBreaker) - if ok { - cb = cbv - } else { - fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error()) - cb = noopCircuitBreaker{} - } - } else { - fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error()) - cb = noopCircuitBreaker{} - } - } else { - fmt.Println("circuit_breaker module not loaded, using noop") - cb = noopCircuitBreaker{} - } - } else { - cb = noopCircuitBreaker{} - } - nu.CB = cb - - // start a healthcheck worker which will periodically check to see if an upstream is healthy - // to proxy requests to. - nu.healthChecker = NewHealthCheckWorker(nu, defaultHealthCheckDur, defaultHTTPClient) - - // TODO :- if path is empty why does this empty the entire Target? - // nu.Target.Path = uc.HealthCheckPath - - nu.healthChecker.ScheduleChecks(nu.Target.String()) - lb.HealthCheckers = append(lb.HealthCheckers, nu.healthChecker) - - us = append(us, nu) - } - - lb.upstreams = us - - return nil -} - -// LoadBalanced represents a collection of upstream hosts that are loadbalanced. It -// contains multiple features like health checking and circuit breaking functionality -// for upstreams. -type LoadBalanced struct { - mu sync.Mutex - numUnhealthy int32 - selectedServer int // used during round robin load balancing - loadBalanceType int - tryInterval time.Duration - upstreams []*upstream - - // The following struct fields are set by caddy configuration. - // TryInterval is the max duration for which request retrys will be performed for a request. - TryInterval string `json:"try_interval,omitempty"` - - // Upstreams are the configs for upstream hosts - Upstreams []*UpstreamConfig `json:"upstreams,omitempty"` - - // LoadBalanceType is the string representation of what loadbalancing algorithm to use. i.e. "random" or "round_robin". - LoadBalanceType string `json:"load_balance_type,omitempty"` - - // NoHealthyUpstreamsMessage is returned as a response when there are no healthy upstreams to loadbalance to. - NoHealthyUpstreamsMessage string `json:"no_healthy_upstreams_message,omitempty"` - - // TODO :- store healthcheckers as package level state where each upstream gets a single healthchecker - // currently a healthchecker is created for each upstream defined, even if a healthchecker was previously created - // for that upstream - HealthCheckers []*HealthChecker `json:"health_checkers,omitempty"` -} - -// Cleanup stops all health checkers on a loadbalanced reverse proxy. -func (lb *LoadBalanced) Cleanup() error { - for _, hc := range lb.HealthCheckers { - hc.Stop() - } - - return nil -} - -// Provision sets up a new loadbalanced reverse proxy. -func (lb *LoadBalanced) Provision(ctx caddy.Context) error { - return NewLoadBalancedReverseProxy(lb, ctx) -} - -// ServeHTTP implements the caddyhttp.MiddlewareHandler interface to -// dispatch an HTTP request to the proper server. -func (lb *LoadBalanced) ServeHTTP(w http.ResponseWriter, r *http.Request, _ caddyhttp.Handler) error { - // ensure requests don't hang if an upstream does not respond or is not eventually healthy - var u *upstream - var done bool - - retryTimer := time.NewTicker(lb.tryInterval) - defer retryTimer.Stop() - - go func() { - select { - case <-retryTimer.C: - done = true - } - }() - - // keep trying to get an available upstream to process the request - for { - switch lb.loadBalanceType { - case TypeBalanceRandom: - u = lb.random() - case TypeBalanceRoundRobin: - u = lb.roundRobin() - } - - // if we can't get an upstream and our retry interval has ended return an error response - if u == nil && done { - w.WriteHeader(http.StatusBadGateway) - fmt.Fprint(w, lb.NoHealthyUpstreamsMessage) - - return fmt.Errorf(msgNoHealthyUpstreams) - } - - // attempt to get an available upstream - if u == nil { - continue - } - - start := time.Now() - - // if we get an error retry until we get a healthy upstream - res, err := u.ReverseProxy.ServeHTTP(w, r) - if err != nil { - if err == context.Canceled { - return nil - } - - continue - } - - // record circuit breaker metrics - go u.CB.RecordMetric(res.StatusCode, time.Now().Sub(start)) - - return nil - } -} - -// incrUnhealthy increments the amount of unhealthy nodes in a loadbalancer. -func (lb *LoadBalanced) incrUnhealthy() { - atomic.AddInt32(&lb.numUnhealthy, 1) -} - -// decrUnhealthy decrements the amount of unhealthy nodes in a loadbalancer. -func (lb *LoadBalanced) decrUnhealthy() { - atomic.AddInt32(&lb.numUnhealthy, -1) -} - -// roundRobin implements a round robin load balancing algorithm to select -// which server to forward requests to. -func (lb *LoadBalanced) roundRobin() *upstream { - if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) { - return nil - } - - selected := lb.upstreams[lb.selectedServer] - - lb.mu.Lock() - lb.selectedServer++ - if lb.selectedServer >= len(lb.upstreams) { - lb.selectedServer = 0 - } - lb.mu.Unlock() - - if selected.IsHealthy() && selected.CB.Ok() { - return selected - } - - return nil -} - -// random implements a random server selector for load balancing. -func (lb *LoadBalanced) random() *upstream { - if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) { - return nil - } - - n := rand.Int() % len(lb.upstreams) - selected := lb.upstreams[n] - - if selected.IsHealthy() && selected.CB.Ok() { - return selected - } - - return nil -} - -// UpstreamConfig represents the config of an upstream. -type UpstreamConfig struct { - // Host is the host name of the upstream server. - Host string `json:"host,omitempty"` - - // FastHealthCheckDuration is the duration for which a health check is performed when a node is considered unhealthy. - FastHealthCheckDuration string `json:"fast_health_check_duration,omitempty"` - - CircuitBreaker json.RawMessage `json:"circuit_breaker,omitempty"` - - // // CircuitBreakerConfig is the config passed to setup a circuit breaker. - // CircuitBreakerConfig *circuitbreaker.Config `json:"circuit_breaker,omitempty"` - circuitbreaker CircuitBreaker - - // HealthCheckDuration is the default duration for which a health check is performed. - HealthCheckDuration string `json:"health_check_duration,omitempty"` - - // HealthCheckPath is the path at the upstream host to use for healthchecks. - HealthCheckPath string `json:"health_check_path,omitempty"` -} - -// upstream represents an upstream host. -type upstream struct { - Healthy int32 // 0 = false, 1 = true - Target *url.URL - ReverseProxy *ReverseProxy - Incr func() - Decr func() - CB CircuitBreaker - healthChecker *HealthChecker - healthCheckDur time.Duration - fastHealthCheckDur time.Duration -} - -// newUpstream returns a new upstream. -func newUpstream(uc *UpstreamConfig, d func(), i func()) (*upstream, error) { - host := strings.TrimSpace(uc.Host) - protoIdx := strings.Index(host, "://") - if protoIdx == -1 || len(host[:protoIdx]) == 0 { - return nil, fmt.Errorf("protocol is required for host") - } - - hostURL, err := url.Parse(host) - if err != nil { - return nil, err - } - - // parse healthcheck durations - hcd, err := time.ParseDuration(uc.HealthCheckDuration) - if err != nil { - hcd = defaultHealthCheckDur - } - - fhcd, err := time.ParseDuration(uc.FastHealthCheckDuration) - if err != nil { - fhcd = defaultFastHealthCheckDur - } - - u := upstream{ - healthCheckDur: hcd, - fastHealthCheckDur: fhcd, - Target: hostURL, - Decr: d, - Incr: i, - Healthy: int32(0), // assume is unhealthy on start - } - - u.ReverseProxy = newReverseProxy(hostURL, u.SetHealthiness) - return &u, nil -} - -// SetHealthiness sets whether an upstream is healthy or not. The health check worker is updated to -// perform checks faster if a node is unhealthy. -func (u *upstream) SetHealthiness(ok bool) { - h := atomic.LoadInt32(&u.Healthy) - var wasHealthy bool - if h == 1 { - wasHealthy = true - } else { - wasHealthy = false - } - - if ok { - u.healthChecker.Ticker = time.NewTicker(u.healthCheckDur) - - if !wasHealthy { - atomic.AddInt32(&u.Healthy, 1) - u.Decr() - } - } else { - u.healthChecker.Ticker = time.NewTicker(u.fastHealthCheckDur) - - if wasHealthy { - atomic.AddInt32(&u.Healthy, -1) - u.Incr() - } - } -} - -// IsHealthy returns whether an Upstream is healthy or not. -func (u *upstream) IsHealthy() bool { - i := atomic.LoadInt32(&u.Healthy) - if i == 1 { - return true - } - - return false -} - -// newReverseProxy returns a new reverse proxy handler. -func newReverseProxy(target *url.URL, setHealthiness func(bool)) *ReverseProxy { - errorHandler := func(w http.ResponseWriter, r *http.Request, err error) { - // we don't need to worry about cancelled contexts since this doesn't necessarilly mean that - // the upstream is unhealthy. - if err != context.Canceled { - setHealthiness(false) - } - } - - rp := NewSingleHostReverseProxy(target) - rp.ErrorHandler = errorHandler - rp.Transport = defaultTransport // use default transport that times out in 5 seconds - return rp -} - -// Interface guards -var ( - _ caddyhttp.MiddlewareHandler = (*LoadBalanced)(nil) - _ caddy.Provisioner = (*LoadBalanced)(nil) - _ caddy.CleanerUpper = (*LoadBalanced)(nil) -) -- cgit v1.2.3 From ccfb12347b1d2f65b279352116527df430e0fba6 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Tue, 3 Sep 2019 12:10:11 -0600 Subject: reverse_proxy: Implement active health checks --- modules/caddyhttp/reverseproxy/healthchecks.go | 220 +++++++++++++++++++++++++ modules/caddyhttp/reverseproxy/reverseproxy.go | 155 ++++++++--------- 2 files changed, 301 insertions(+), 74 deletions(-) create mode 100644 modules/caddyhttp/reverseproxy/healthchecks.go (limited to 'modules') diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go new file mode 100644 index 0000000..96649a4 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -0,0 +1,220 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "net/url" + "regexp" + "strconv" + "time" + + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" +) + +type HealthChecks struct { + Active *ActiveHealthChecks `json:"active,omitempty"` + Passive *PassiveHealthChecks `json:"passive,omitempty"` +} + +type ActiveHealthChecks struct { + Path string `json:"path,omitempty"` + Port int `json:"port,omitempty"` + Interval caddy.Duration `json:"interval,omitempty"` + Timeout caddy.Duration `json:"timeout,omitempty"` + MaxSize int64 `json:"max_size,omitempty"` + ExpectStatus int `json:"expect_status,omitempty"` + ExpectBody string `json:"expect_body,omitempty"` + + stopChan chan struct{} + httpClient *http.Client + bodyRegexp *regexp.Regexp +} + +type PassiveHealthChecks struct { + MaxFails int `json:"max_fails,omitempty"` + FailDuration caddy.Duration `json:"fail_duration,omitempty"` + UnhealthyRequestCount int `json:"unhealthy_request_count,omitempty"` + UnhealthyStatus []int `json:"unhealthy_status,omitempty"` + UnhealthyLatency caddy.Duration `json:"unhealthy_latency,omitempty"` +} + +func (h *Handler) activeHealthChecker() { + ticker := time.NewTicker(time.Duration(h.HealthChecks.Active.Interval)) + h.doActiveHealthChecksForAllHosts() + for { + select { + case <-ticker.C: + h.doActiveHealthChecksForAllHosts() + case <-h.HealthChecks.Active.stopChan: + ticker.Stop() + return + } + } +} + +func (h *Handler) doActiveHealthChecksForAllHosts() { + hosts.Range(func(key, value interface{}) bool { + addr := key.(string) + host := value.(Host) + + go func(addr string, host Host) { + err := h.doActiveHealthCheck(addr, host) + if err != nil { + log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", addr, err) + } + }(addr, host) + + // continue to iterate all hosts + return true + }) +} + +// doActiveHealthCheck performs a health check to host which +// can be reached at address hostAddr. The actual address for +// the request will be built according to active health checker +// config. The health status of the host will be updated +// according to whether it passes the health check. An error is +// returned only if the health check fails to occur or if marking +// the host's health status fails. +func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error { + // create the URL for the health check + u, err := url.Parse(hostAddr) + if err != nil { + return err + } + if h.HealthChecks.Active.Path != "" { + u.Path = h.HealthChecks.Active.Path + } + if h.HealthChecks.Active.Port != 0 { + portStr := strconv.Itoa(h.HealthChecks.Active.Port) + u.Host = net.JoinHostPort(u.Hostname(), portStr) + } + + req, err := http.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return err + } + + // do the request, careful to tame the response body + resp, err := h.HealthChecks.Active.httpClient.Do(req) + if err != nil { + log.Printf("[INFO] reverse_proxy: active health check: %s is down (HTTP request failed: %v)", hostAddr, err) + _, err2 := host.SetHealthy(false) + if err2 != nil { + return fmt.Errorf("marking unhealthy: %v", err2) + } + return nil + } + var body io.Reader = resp.Body + if h.HealthChecks.Active.MaxSize > 0 { + body = io.LimitReader(body, h.HealthChecks.Active.MaxSize) + } + defer func() { + // drain any remaining body so connection can be re-used + io.Copy(ioutil.Discard, body) + resp.Body.Close() + }() + + // if status code is outside criteria, mark down + if h.HealthChecks.Active.ExpectStatus > 0 { + if !caddyhttp.StatusCodeMatches(resp.StatusCode, h.HealthChecks.Active.ExpectStatus) { + log.Printf("[INFO] reverse_proxy: active health check: %s is down (status code %d unexpected)", hostAddr, resp.StatusCode) + _, err := host.SetHealthy(false) + if err != nil { + return fmt.Errorf("marking unhealthy: %v", err) + } + return nil + } + } else if resp.StatusCode < 200 || resp.StatusCode >= 400 { + log.Printf("[INFO] reverse_proxy: active health check: %s is down (status code %d out of tolerances)", hostAddr, resp.StatusCode) + _, err := host.SetHealthy(false) + if err != nil { + return fmt.Errorf("marking unhealthy: %v", err) + } + return nil + } + + // if body does not match regex, mark down + if h.HealthChecks.Active.bodyRegexp != nil { + bodyBytes, err := ioutil.ReadAll(body) + if err != nil { + log.Printf("[INFO] reverse_proxy: active health check: %s is down (failed to read response body)", hostAddr) + _, err := host.SetHealthy(false) + if err != nil { + return fmt.Errorf("marking unhealthy: %v", err) + } + return nil + } + if !h.HealthChecks.Active.bodyRegexp.Match(bodyBytes) { + log.Printf("[INFO] reverse_proxy: active health check: %s is down (response body failed expectations)", hostAddr) + _, err := host.SetHealthy(false) + if err != nil { + return fmt.Errorf("marking unhealthy: %v", err) + } + return nil + } + } + + // passed health check parameters, so mark as healthy + swapped, err := host.SetHealthy(true) + if swapped { + log.Printf("[INFO] reverse_proxy: active health check: %s is back up", hostAddr) + } + if err != nil { + return fmt.Errorf("marking healthy: %v", err) + } + + return nil +} + +// countFailure is used with passive health checks. It +// remembers 1 failure for upstream for the configured +// duration. If passive health checks are disabled or +// failure expiry is 0, this is a no-op. +func (h Handler) countFailure(upstream *Upstream) { + // only count failures if passive health checking is enabled + // and if failures are configured have a non-zero expiry + if h.HealthChecks == nil || h.HealthChecks.Passive == nil { + return + } + failDuration := time.Duration(h.HealthChecks.Passive.FailDuration) + if failDuration == 0 { + return + } + + // count failure immediately + err := upstream.Host.CountFail(1) + if err != nil { + log.Printf("[ERROR] proxy: upstream %s: counting failure: %v", + upstream.hostURL, err) + } + + // forget it later + go func(host Host, failDuration time.Duration) { + time.Sleep(failDuration) + err := host.CountFail(-1) + if err != nil { + log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v", + upstream.hostURL, err) + } + }(upstream.Host, failDuration) +} diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index e312d71..ebf6ac1 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -15,15 +15,14 @@ package reverseproxy import ( - "bytes" "context" "encoding/json" "fmt" "io" - "log" "net" "net/http" "net/url" + "regexp" "strings" "sync" "sync/atomic" @@ -90,11 +89,41 @@ func (h *Handler) Provision(ctx caddy.Context) error { if h.LoadBalancing.TryDuration > 0 && h.LoadBalancing.TryInterval == 0 { // a non-zero try_duration with a zero try_interval // will always spin the CPU for try_duration if the - // upstream is local or low-latency; default to some - // sane waiting period before try attempts + // upstream is local or low-latency; avoid that by + // defaulting to a sane wait period between attempts h.LoadBalancing.TryInterval = caddy.Duration(250 * time.Millisecond) } + // if active health checks are enabled, configure them and start a worker + if h.HealthChecks != nil && + h.HealthChecks.Active != nil && + (h.HealthChecks.Active.Path != "" || h.HealthChecks.Active.Port != 0) { + timeout := time.Duration(h.HealthChecks.Active.Timeout) + if timeout == 0 { + timeout = 10 * time.Second + } + + h.HealthChecks.Active.stopChan = make(chan struct{}) + h.HealthChecks.Active.httpClient = &http.Client{ + Timeout: timeout, + Transport: h.Transport, + } + + if h.HealthChecks.Active.Interval == 0 { + h.HealthChecks.Active.Interval = caddy.Duration(30 * time.Second) + } + + if h.HealthChecks.Active.ExpectBody != "" { + var err error + h.HealthChecks.Active.bodyRegexp, err = regexp.Compile(h.HealthChecks.Active.ExpectBody) + if err != nil { + return fmt.Errorf("expect_body: compiling regular expression: %v", err) + } + } + + go h.activeHealthChecker() + } + for _, upstream := range h.Upstreams { // url parser requires a scheme if !strings.Contains(upstream.Address, "://") { @@ -130,8 +159,6 @@ func (h *Handler) Provision(ctx caddy.Context) error { upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount } - // TODO: active health checks - if h.HealthChecks != nil { // upstreams need independent access to the passive // health check policy so they can, you know, passively @@ -143,11 +170,20 @@ func (h *Handler) Provision(ctx caddy.Context) error { return nil } +// Cleanup cleans up the resources made by h during provisioning. func (h *Handler) Cleanup() error { - // TODO: finish this up, make sure it takes care of any active health checkers or whatever + // stop the active health checker + if h.HealthChecks != nil && + h.HealthChecks.Active != nil && + h.HealthChecks.Active.stopChan != nil { + close(h.HealthChecks.Active.stopChan) + } + + // remove hosts from our config from the pool for _, upstream := range h.Upstreams { hosts.Delete(upstream.hostURL.String()) } + return nil } @@ -539,38 +575,6 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, er } } -// countFailure remembers 1 failure for upstream for the -// configured duration. If passive health checks are -// disabled or failure expiry is 0, this is a no-op. -func (h Handler) countFailure(upstream *Upstream) { - // only count failures if passive health checking is enabled - // and if failures are configured have a non-zero expiry - if h.HealthChecks == nil || h.HealthChecks.Passive == nil { - return - } - failDuration := time.Duration(h.HealthChecks.Passive.FailDuration) - if failDuration == 0 { - return - } - - // count failure immediately - err := upstream.Host.CountFail(1) - if err != nil { - log.Printf("[ERROR] proxy: upstream %s: counting failure: %v", - upstream.hostURL, err) - } - - // forget it later - go func(host Host, failDuration time.Duration) { - time.Sleep(failDuration) - err := host.CountFail(-1) - if err != nil { - log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v", - upstream.hostURL, err) - } - }(upstream.Host, failDuration) -} - type writeFlusher interface { io.Writer http.Flusher @@ -722,29 +726,6 @@ type Selector interface { Select(HostPool, *http.Request) *Upstream } -type HealthChecks struct { - Active *ActiveHealthChecks `json:"active,omitempty"` - Passive *PassiveHealthChecks `json:"passive,omitempty"` -} - -type ActiveHealthChecks struct { - Path string `json:"path,omitempty"` - Port int `json:"port,omitempty"` - Interval caddy.Duration `json:"interval,omitempty"` - Timeout caddy.Duration `json:"timeout,omitempty"` - MaxSize int `json:"max_size,omitempty"` - ExpectStatus int `json:"expect_status,omitempty"` - ExpectBody string `json:"expect_body,omitempty"` -} - -type PassiveHealthChecks struct { - MaxFails int `json:"max_fails,omitempty"` - FailDuration caddy.Duration `json:"fail_duration,omitempty"` - UnhealthyRequestCount int `json:"unhealthy_request_count,omitempty"` - UnhealthyStatus []int `json:"unhealthy_status,omitempty"` - UnhealthyLatency caddy.Duration `json:"unhealthy_latency,omitempty"` -} - // Hop-by-hop headers. These are removed when sent to the backend. // As of RFC 7230, hop-by-hop headers are required to appear in the // Connection header field. These are the headers defined by the @@ -762,22 +743,33 @@ var hopHeaders = []string{ "Upgrade", } -var bufPool = sync.Pool{ - New: func() interface{} { - return new(bytes.Buffer) - }, -} - -////////////////////////////////// -// TODO: - +// Host represents a remote host which can be proxied to. +// Its methods must be safe for concurrent use. type Host interface { + // NumRequests returns the numnber of requests + // currently in process with the host. NumRequests() int + + // Fails returns the count of recent failures. Fails() int + + // Unhealthy returns true if the backend is unhealthy. Unhealthy() bool + // CountRequest counts the given number of requests + // as currently in process with the host. The count + // should not go below 0. CountRequest(int) error + + // CountFail counts the given number of failures + // with the host. The count should not go below 0. CountFail(int) error + + // SetHealthy marks the host as either healthy (true) + // or unhealthy (false). If the given status is the + // same, this should be a no-op. It returns true if + // the given status was different, false otherwise. + SetHealthy(bool) (bool, error) } type HostPool []*Upstream @@ -788,13 +780,13 @@ type upstreamHost struct { unhealthy int32 } -func (uh upstreamHost) NumRequests() int { +func (uh *upstreamHost) NumRequests() int { return int(atomic.LoadInt64(&uh.numRequests)) } -func (uh upstreamHost) Fails() int { +func (uh *upstreamHost) Fails() int { return int(atomic.LoadInt64(&uh.fails)) } -func (uh upstreamHost) Unhealthy() bool { +func (uh *upstreamHost) Unhealthy() bool { return atomic.LoadInt32(&uh.unhealthy) == 1 } func (uh *upstreamHost) CountRequest(delta int) error { @@ -811,6 +803,14 @@ func (uh *upstreamHost) CountFail(delta int) error { } return nil } +func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) { + var unhealthy, compare int32 = 1, 0 + if healthy { + unhealthy, compare = 0, 1 + } + swapped := atomic.CompareAndSwapInt32(&uh.unhealthy, compare, unhealthy) + return swapped, nil +} type Upstream struct { Host `json:"-"` @@ -854,6 +854,13 @@ var hosts = caddy.NewUsagePool() type UpstreamProvider interface { } +// TODO: see if we can use this +// var bufPool = sync.Pool{ +// New: func() interface{} { +// return new(bytes.Buffer) +// }, +// } + // Interface guards var ( _ caddyhttp.MiddlewareHandler = (*Handler)(nil) -- cgit v1.2.3 From 4a1e1649bc985e9658d326ed433a101d7d79ae30 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Tue, 3 Sep 2019 15:26:09 -0600 Subject: reverse_proxy: Implement remaining TLS config for proxy to backend --- modules/caddyhttp/reverseproxy/httptransport.go | 86 +++++++++++++++++++++++-- 1 file changed, 81 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index 36dd776..999a352 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -15,8 +15,13 @@ package reverseproxy import ( + "crypto/tls" + "crypto/x509" + "encoding/base64" + "fmt" "net" "net/http" + "reflect" "time" "github.com/caddyserver/caddy/v2" @@ -76,7 +81,12 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error { if h.TLS != nil { rt.TLSHandshakeTimeout = time.Duration(h.TLS.HandshakeTimeout) - // TODO: rest of TLS config + + var err error + rt.TLSClientConfig, err = h.TLS.MakeTLSClientConfig() + if err != nil { + return fmt.Errorf("making TLS client config: %v", err) + } } if h.KeepAlive != nil { @@ -103,11 +113,77 @@ func (h HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { return h.RoundTripper.RoundTrip(req) } +func defaultTLSConfig() *tls.Config { + return &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, // TODO: ensure this makes HTTP/2 work + } +} + type TLSConfig struct { - CAPool []string `json:"ca_pool,omitempty"` - ClientCertificate string `json:"client_certificate,omitempty"` - InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` - HandshakeTimeout caddy.Duration `json:"handshake_timeout,omitempty"` + RootCAPool []string `json:"root_ca_pool,omitempty"` + // TODO: Should the client cert+key config use caddytls.CertificateLoader modules? + ClientCertificateFile string `json:"client_certificate_file,omitempty"` + ClientCertificateKeyFile string `json:"client_certificate_key_file,omitempty"` + InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` + HandshakeTimeout caddy.Duration `json:"handshake_timeout,omitempty"` +} + +// MakeTLSClientConfig returns a tls.Config usable by a client to a backend. +// If there is no custom TLS configuration, a nil config may be returned. +func (t TLSConfig) MakeTLSClientConfig() (*tls.Config, error) { + cfg := new(tls.Config) + + // client auth + if t.ClientCertificateFile != "" && t.ClientCertificateKeyFile == "" { + return nil, fmt.Errorf("client_certificate_file specified without client_certificate_key_file") + } + if t.ClientCertificateFile == "" && t.ClientCertificateKeyFile != "" { + return nil, fmt.Errorf("client_certificate_key_file specified without client_certificate_file") + } + if t.ClientCertificateFile != "" && t.ClientCertificateKeyFile != "" { + cert, err := tls.LoadX509KeyPair(t.ClientCertificateFile, t.ClientCertificateKeyFile) + if err != nil { + return nil, fmt.Errorf("loading client certificate key pair: %v", err) + } + cfg.Certificates = []tls.Certificate{cert} + } + + // trusted root CAs + if len(t.RootCAPool) > 0 { + rootPool := x509.NewCertPool() + for _, encodedCACert := range t.RootCAPool { + caCert, err := decodeBase64DERCert(encodedCACert) + if err != nil { + return nil, fmt.Errorf("parsing CA certificate: %v", err) + } + rootPool.AddCert(caCert) + } + cfg.RootCAs = rootPool + } + + // throw all security out the window + cfg.InsecureSkipVerify = t.InsecureSkipVerify + + // only return a config if it's not empty + if reflect.DeepEqual(cfg, new(tls.Config)) { + return nil, nil + } + + cfg.NextProtos = []string{"h2", "http/1.1"} // TODO: ensure that this actually enables HTTP/2 + + return cfg, nil +} + +// decodeBase64DERCert base64-decodes, then DER-decodes, certStr. +func decodeBase64DERCert(certStr string) (*x509.Certificate, error) { + // decode base64 + derBytes, err := base64.StdEncoding.DecodeString(certStr) + if err != nil { + return nil, err + } + + // parse the DER-encoded certificate + return x509.ParseCertificate(derBytes) } type KeepAlive struct { -- cgit v1.2.3 From 652460e03e11a037d9f86b09b3546c9e42733d2d Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Tue, 3 Sep 2019 16:56:09 -0600 Subject: Some cleanup and godoc --- modules/caddyhttp/reverseproxy/healthchecks.go | 12 + modules/caddyhttp/reverseproxy/hosts.go | 161 ++++++++++ modules/caddyhttp/reverseproxy/httptransport.go | 21 +- modules/caddyhttp/reverseproxy/reverseproxy.go | 355 +-------------------- .../caddyhttp/reverseproxy/selectionpolicies.go | 18 +- modules/caddyhttp/reverseproxy/streaming.go | 223 +++++++++++++ 6 files changed, 428 insertions(+), 362 deletions(-) create mode 100644 modules/caddyhttp/reverseproxy/hosts.go create mode 100644 modules/caddyhttp/reverseproxy/streaming.go (limited to 'modules') diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go index 96649a4..0b46d04 100644 --- a/modules/caddyhttp/reverseproxy/healthchecks.go +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -30,11 +30,15 @@ import ( "github.com/caddyserver/caddy/v2/modules/caddyhttp" ) +// HealthChecks holds configuration related to health checking. type HealthChecks struct { Active *ActiveHealthChecks `json:"active,omitempty"` Passive *PassiveHealthChecks `json:"passive,omitempty"` } +// ActiveHealthChecks holds configuration related to active +// health checks (that is, health checks which occur in a +// background goroutine independently). type ActiveHealthChecks struct { Path string `json:"path,omitempty"` Port int `json:"port,omitempty"` @@ -49,6 +53,9 @@ type ActiveHealthChecks struct { bodyRegexp *regexp.Regexp } +// PassiveHealthChecks holds configuration related to passive +// health checks (that is, health checks which occur during +// the normal flow of request proxying). type PassiveHealthChecks struct { MaxFails int `json:"max_fails,omitempty"` FailDuration caddy.Duration `json:"fail_duration,omitempty"` @@ -57,6 +64,9 @@ type PassiveHealthChecks struct { UnhealthyLatency caddy.Duration `json:"unhealthy_latency,omitempty"` } +// activeHealthChecker runs active health checks on a +// regular basis and blocks until +// h.HealthChecks.Active.stopChan is closed. func (h *Handler) activeHealthChecker() { ticker := time.NewTicker(time.Duration(h.HealthChecks.Active.Interval)) h.doActiveHealthChecksForAllHosts() @@ -71,6 +81,8 @@ func (h *Handler) activeHealthChecker() { } } +// doActiveHealthChecksForAllHosts immediately performs a +// health checks for all hosts in the global repository. func (h *Handler) doActiveHealthChecksForAllHosts() { hosts.Range(func(key, value interface{}) bool { addr := key.(string) diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go new file mode 100644 index 0000000..5100936 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -0,0 +1,161 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "fmt" + "net/url" + "sync/atomic" + + "github.com/caddyserver/caddy/v2" +) + +// Host represents a remote host which can be proxied to. +// Its methods must be safe for concurrent use. +type Host interface { + // NumRequests returns the numnber of requests + // currently in process with the host. + NumRequests() int + + // Fails returns the count of recent failures. + Fails() int + + // Unhealthy returns true if the backend is unhealthy. + Unhealthy() bool + + // CountRequest counts the given number of requests + // as currently in process with the host. The count + // should not go below 0. + CountRequest(int) error + + // CountFail counts the given number of failures + // with the host. The count should not go below 0. + CountFail(int) error + + // SetHealthy marks the host as either healthy (true) + // or unhealthy (false). If the given status is the + // same, this should be a no-op. It returns true if + // the given status was different, false otherwise. + SetHealthy(bool) (bool, error) +} + +// UpstreamPool is a collection of upstreams. +type UpstreamPool []*Upstream + +// Upstream bridges this proxy's configuration to the +// state of the backend host it is correlated with. +type Upstream struct { + Host `json:"-"` + + Address string `json:"address,omitempty"` + MaxRequests int `json:"max_requests,omitempty"` + + // TODO: This could be really useful, to bind requests + // with certain properties to specific backends + // HeaderAffinity string + // IPAffinity string + + healthCheckPolicy *PassiveHealthChecks + hostURL *url.URL +} + +// Available returns true if the remote host +// is available to receive requests. +func (u *Upstream) Available() bool { + return u.Healthy() && !u.Full() +} + +// Healthy returns true if the remote host +// is currently known to be healthy or "up". +func (u *Upstream) Healthy() bool { + healthy := !u.Host.Unhealthy() + if healthy && u.healthCheckPolicy != nil { + healthy = u.Host.Fails() < u.healthCheckPolicy.MaxFails + } + return healthy +} + +// Full returns true if the remote host +// cannot receive more requests at this time. +func (u *Upstream) Full() bool { + return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests +} + +// URL returns the upstream host's endpoint URL. +func (u *Upstream) URL() *url.URL { + return u.hostURL +} + +// upstreamHost is the basic, in-memory representation +// of the state of a remote host. It implements the +// Host interface. +type upstreamHost struct { + numRequests int64 // must be first field to be 64-bit aligned on 32-bit systems (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG) + fails int64 + unhealthy int32 +} + +// NumRequests returns the number of active requests to the upstream. +func (uh *upstreamHost) NumRequests() int { + return int(atomic.LoadInt64(&uh.numRequests)) +} + +// Fails returns the number of recent failures with the upstream. +func (uh *upstreamHost) Fails() int { + return int(atomic.LoadInt64(&uh.fails)) +} + +// Unhealthy returns whether the upstream is healthy. +func (uh *upstreamHost) Unhealthy() bool { + return atomic.LoadInt32(&uh.unhealthy) == 1 +} + +// CountRequest mutates the active request count by +// delta. It returns an error if the adjustment fails. +func (uh *upstreamHost) CountRequest(delta int) error { + result := atomic.AddInt64(&uh.numRequests, int64(delta)) + if result < 0 { + return fmt.Errorf("count below 0: %d", result) + } + return nil +} + +// CountFail mutates the recent failures count by +// delta. It returns an error if the adjustment fails. +func (uh *upstreamHost) CountFail(delta int) error { + result := atomic.AddInt64(&uh.fails, int64(delta)) + if result < 0 { + return fmt.Errorf("count below 0: %d", result) + } + return nil +} + +// SetHealthy sets the upstream has healthy or unhealthy +// and returns true if the value was different from before, +// or an error if the adjustment failed. +func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) { + var unhealthy, compare int32 = 1, 0 + if healthy { + unhealthy, compare = 0, 1 + } + swapped := atomic.CompareAndSwapInt32(&uh.unhealthy, compare, unhealthy) + return swapped, nil +} + +// hosts is the global repository for hosts that are +// currently in use by active configuration(s). This +// allows the state of remote hosts to be preserved +// through config reloads. +var hosts = caddy.NewUsagePool() diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index 999a352..d9dc457 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -31,14 +31,13 @@ func init() { caddy.RegisterModule(HTTPTransport{}) } -// TODO: This is the default transport, basically just http.Transport, but we define JSON struct tags... +// HTTPTransport is essentially a configuration wrapper for http.Transport. +// It defines a JSON structure useful when configuring the HTTP transport +// for Caddy's reverse proxy. type HTTPTransport struct { - // TODO: Actually this is where the TLS config should go, technically... - // as well as keepalives and dial timeouts... // TODO: It's possible that other transports (like fastcgi) might be // able to borrow/use at least some of these config fields; if so, // move them into a type called CommonTransport and embed it - TLS *TLSConfig `json:"tls,omitempty"` KeepAlive *KeepAlive `json:"keep_alive,omitempty"` Compression *bool `json:"compression,omitempty"` @@ -50,7 +49,6 @@ type HTTPTransport struct { MaxResponseHeaderSize int64 `json:"max_response_header_size,omitempty"` WriteBufferSize int `json:"write_buffer_size,omitempty"` ReadBufferSize int `json:"read_buffer_size,omitempty"` - // TODO: ProxyConnectHeader? RoundTripper http.RoundTripper `json:"-"` } @@ -63,6 +61,8 @@ func (HTTPTransport) CaddyModule() caddy.ModuleInfo { } } +// Provision sets up h.RoundTripper with a http.Transport +// that is ready to use. func (h *HTTPTransport) Provision(ctx caddy.Context) error { dialer := &net.Dialer{ Timeout: time.Duration(h.DialTimeout), @@ -109,16 +109,13 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error { return nil } +// RoundTrip implements http.RoundTripper with h.RoundTripper. func (h HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { return h.RoundTripper.RoundTrip(req) } -func defaultTLSConfig() *tls.Config { - return &tls.Config{ - NextProtos: []string{"h2", "http/1.1"}, // TODO: ensure this makes HTTP/2 work - } -} - +// TLSConfig holds configuration related to the +// TLS configuration for the transport/client. type TLSConfig struct { RootCAPool []string `json:"root_ca_pool,omitempty"` // TODO: Should the client cert+key config use caddytls.CertificateLoader modules? @@ -186,6 +183,7 @@ func decodeBase64DERCert(certStr string) (*x509.Certificate, error) { return x509.ParseCertificate(derBytes) } +// KeepAlive holds configuration pertaining to HTTP Keep-Alive. type KeepAlive struct { Enabled *bool `json:"enabled,omitempty"` ProbeInterval caddy.Duration `json:"probe_interval,omitempty"` @@ -200,7 +198,6 @@ var ( KeepAlive: 30 * time.Second, } - // TODO: does this need to be configured to enable HTTP/2? defaultTransport = &http.Transport{ DialContext: defaultDialer.DialContext, TLSHandshakeTimeout: 5 * time.Second, diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index ebf6ac1..ca54741 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -18,14 +18,11 @@ import ( "context" "encoding/json" "fmt" - "io" "net" "net/http" "net/url" "regexp" "strings" - "sync" - "sync/atomic" "time" "github.com/caddyserver/caddy/v2" @@ -37,14 +34,14 @@ func init() { caddy.RegisterModule(Handler{}) } +// Handler implements a highly configurable and production-ready reverse proxy. type Handler struct { TransportRaw json.RawMessage `json:"transport,omitempty"` LoadBalancing *LoadBalancing `json:"load_balancing,omitempty"` HealthChecks *HealthChecks `json:"health_checks,omitempty"` - // UpstreamStorageRaw json.RawMessage `json:"upstream_storage,omitempty"` // TODO: - Upstreams HostPool `json:"upstreams,omitempty"` + Upstreams UpstreamPool `json:"upstreams,omitempty"` + FlushInterval caddy.Duration `json:"flush_interval,omitempty"` - // UpstreamProvider UpstreamProvider `json:"-"` // TODO: Transport http.RoundTripper `json:"-"` } @@ -56,6 +53,7 @@ func (Handler) CaddyModule() caddy.ModuleInfo { } } +// Provision ensures that h is set up properly before use. func (h *Handler) Provision(ctx caddy.Context) error { if h.TransportRaw != nil { val, err := ctx.LoadModuleInline("protocol", "http.handlers.reverse_proxy.transport", h.TransportRaw) @@ -236,34 +234,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht // This assumes that no mutations of the request are performed // by h during or after proxying. func (h Handler) prepareRequest(req *http.Request) error { - // ctx := req.Context() - // TODO: do we need to support CloseNotifier? It was deprecated years ago. - // All this does is wrap CloseNotify with context cancel, for those responsewriters - // which didn't support context, but all the ones we'd use should nowadays, right? - // if cn, ok := rw.(http.CloseNotifier); ok { - // var cancel context.CancelFunc - // ctx, cancel = context.WithCancel(ctx) - // defer cancel() - // notifyChan := cn.CloseNotify() - // go func() { - // select { - // case <-notifyChan: - // cancel() - // case <-ctx.Done(): - // } - // }() - // } - - // TODO: do we need to call WithContext, since we won't be changing req.Context() above if we remove the CloseNotifier stuff? - // TODO: (This is where references to req were originally "outreq", a shallow clone, which I think is unnecessary in our case) - // req = req.WithContext(ctx) // includes shallow copies of maps, but okay if req.ContentLength == 0 { req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries } - // TODO: is this needed? - // req.Header = cloneHeader(req.Header) - req.Close = false // if User-Agent is not set by client, then explicitly @@ -315,10 +289,12 @@ func (h Handler) prepareRequest(req *http.Request) error { return nil } -// TODO: -// this code is the entry point to what was borrowed from the net/http/httputil package in the standard library. +// reverseProxy performs a round-trip to the given backend and processes the response with the client. +// (This method is mostly the beginning of what was borrowed from the net/http/httputil package in the +// Go standard library which was used as the foundation.) func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, upstream *Upstream) error { - // TODO: count this active request + upstream.Host.CountRequest(1) + defer upstream.Host.CountRequest(-1) // point the request to this upstream h.directRequest(req, upstream) @@ -448,202 +424,6 @@ func (h Handler) directRequest(req *http.Request, upstream *Upstream) { } } -func (h Handler) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { - reqUpType := upgradeType(req.Header) - resUpType := upgradeType(res.Header) - if reqUpType != resUpType { - // p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) - return - } - - copyHeader(res.Header, rw.Header()) - - hj, ok := rw.(http.Hijacker) - if !ok { - // p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) - return - } - backConn, ok := res.Body.(io.ReadWriteCloser) - if !ok { - // p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) - return - } - defer backConn.Close() - conn, brw, err := hj.Hijack() - if err != nil { - // p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) - return - } - defer conn.Close() - res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above - if err := res.Write(brw); err != nil { - // p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) - return - } - if err := brw.Flush(); err != nil { - // p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) - return - } - errc := make(chan error, 1) - spc := switchProtocolCopier{user: conn, backend: backConn} - go spc.copyToBackend(errc) - go spc.copyFromBackend(errc) - <-errc - return -} - -// flushInterval returns the p.FlushInterval value, conditionally -// overriding its value for a specific request/response. -func (h Handler) flushInterval(req *http.Request, res *http.Response) time.Duration { - resCT := res.Header.Get("Content-Type") - - // For Server-Sent Events responses, flush immediately. - // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream - if resCT == "text/event-stream" { - return -1 // negative means immediately - } - - // TODO: more specific cases? e.g. res.ContentLength == -1? - // return h.FlushInterval - return 0 -} - -func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { - if flushInterval != 0 { - if wf, ok := dst.(writeFlusher); ok { - mlw := &maxLatencyWriter{ - dst: wf, - latency: flushInterval, - } - defer mlw.stop() - dst = mlw - } - } - - // TODO: Figure out how we want to do this... using custom buffer pool type seems unnecessary - // or maybe it is, depending on how we want to handle errors, - // see: https://github.com/golang/go/issues/21814 - // buf := bufPool.Get().(*bytes.Buffer) - // buf.Reset() - // defer bufPool.Put(buf) - // _, err := io.CopyBuffer(dst, src, ) - var buf []byte - // if h.BufferPool != nil { - // buf = h.BufferPool.Get() - // defer h.BufferPool.Put(buf) - // } - _, err := h.copyBuffer(dst, src, buf) - return err -} - -// copyBuffer returns any write errors or non-EOF read errors, and the amount -// of bytes written. -func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { - if len(buf) == 0 { - buf = make([]byte, 32*1024) - } - var written int64 - for { - nr, rerr := src.Read(buf) - if rerr != nil && rerr != io.EOF && rerr != context.Canceled { - // TODO: this could be useful to know (indeed, it revealed an error in our - // fastcgi PoC earlier; but it's this single error report here that necessitates - // a function separate from io.CopyBuffer, since io.CopyBuffer does not distinguish - // between read or write errors; in a reverse proxy situation, write errors are not - // something we need to report to the client, but read errors are a problem on our - // end for sure. so we need to decide what we want.) - // p.logf("copyBuffer: ReverseProxy read error during body copy: %v", rerr) - } - if nr > 0 { - nw, werr := dst.Write(buf[:nr]) - if nw > 0 { - written += int64(nw) - } - if werr != nil { - return written, werr - } - if nr != nw { - return written, io.ErrShortWrite - } - } - if rerr != nil { - if rerr == io.EOF { - rerr = nil - } - return written, rerr - } - } -} - -type writeFlusher interface { - io.Writer - http.Flusher -} - -type maxLatencyWriter struct { - dst writeFlusher - latency time.Duration // non-zero; negative means to flush immediately - - mu sync.Mutex // protects t, flushPending, and dst.Flush - t *time.Timer - flushPending bool -} - -func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { - m.mu.Lock() - defer m.mu.Unlock() - n, err = m.dst.Write(p) - if m.latency < 0 { - m.dst.Flush() - return - } - if m.flushPending { - return - } - if m.t == nil { - m.t = time.AfterFunc(m.latency, m.delayedFlush) - } else { - m.t.Reset(m.latency) - } - m.flushPending = true - return -} - -func (m *maxLatencyWriter) delayedFlush() { - m.mu.Lock() - defer m.mu.Unlock() - if !m.flushPending { // if stop was called but AfterFunc already started this goroutine - return - } - m.dst.Flush() - m.flushPending = false -} - -func (m *maxLatencyWriter) stop() { - m.mu.Lock() - defer m.mu.Unlock() - m.flushPending = false - if m.t != nil { - m.t.Stop() - } -} - -// switchProtocolCopier exists so goroutines proxying data back and -// forth have nice names in stacks. -type switchProtocolCopier struct { - user, backend io.ReadWriter -} - -func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { - _, err := io.Copy(c.user, c.backend) - errc <- err -} - -func (c switchProtocolCopier) copyToBackend(errc chan<- error) { - _, err := io.Copy(c.backend, c.user) - errc <- err -} - // shouldPanicOnCopyError reports whether the reverse proxy should // panic with http.ErrAbortHandler. This is the right thing to do by // default, but Go 1.10 and earlier did not, so existing unit tests @@ -714,6 +494,7 @@ func removeConnectionHeaders(h http.Header) { } } +// LoadBalancing has parameters related to load balancing. type LoadBalancing struct { SelectionPolicyRaw json.RawMessage `json:"selection_policy,omitempty"` TryDuration caddy.Duration `json:"try_duration,omitempty"` @@ -722,8 +503,9 @@ type LoadBalancing struct { SelectionPolicy Selector `json:"-"` } +// Selector selects an available upstream from the pool. type Selector interface { - Select(HostPool, *http.Request) *Upstream + Select(UpstreamPool, *http.Request) *Upstream } // Hop-by-hop headers. These are removed when sent to the backend. @@ -743,117 +525,6 @@ var hopHeaders = []string{ "Upgrade", } -// Host represents a remote host which can be proxied to. -// Its methods must be safe for concurrent use. -type Host interface { - // NumRequests returns the numnber of requests - // currently in process with the host. - NumRequests() int - - // Fails returns the count of recent failures. - Fails() int - - // Unhealthy returns true if the backend is unhealthy. - Unhealthy() bool - - // CountRequest counts the given number of requests - // as currently in process with the host. The count - // should not go below 0. - CountRequest(int) error - - // CountFail counts the given number of failures - // with the host. The count should not go below 0. - CountFail(int) error - - // SetHealthy marks the host as either healthy (true) - // or unhealthy (false). If the given status is the - // same, this should be a no-op. It returns true if - // the given status was different, false otherwise. - SetHealthy(bool) (bool, error) -} - -type HostPool []*Upstream - -type upstreamHost struct { - numRequests int64 // must be first field to be 64-bit aligned on 32-bit systems (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG) - fails int64 - unhealthy int32 -} - -func (uh *upstreamHost) NumRequests() int { - return int(atomic.LoadInt64(&uh.numRequests)) -} -func (uh *upstreamHost) Fails() int { - return int(atomic.LoadInt64(&uh.fails)) -} -func (uh *upstreamHost) Unhealthy() bool { - return atomic.LoadInt32(&uh.unhealthy) == 1 -} -func (uh *upstreamHost) CountRequest(delta int) error { - result := atomic.AddInt64(&uh.numRequests, int64(delta)) - if result < 0 { - return fmt.Errorf("count below 0: %d", result) - } - return nil -} -func (uh *upstreamHost) CountFail(delta int) error { - result := atomic.AddInt64(&uh.fails, int64(delta)) - if result < 0 { - return fmt.Errorf("count below 0: %d", result) - } - return nil -} -func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) { - var unhealthy, compare int32 = 1, 0 - if healthy { - unhealthy, compare = 0, 1 - } - swapped := atomic.CompareAndSwapInt32(&uh.unhealthy, compare, unhealthy) - return swapped, nil -} - -type Upstream struct { - Host `json:"-"` - - Address string `json:"address,omitempty"` - MaxRequests int `json:"max_requests,omitempty"` - - // TODO: This could be really cool, to say that requests with - // certain headers or from certain IPs always go to this upstream - // HeaderAffinity string - // IPAffinity string - - healthCheckPolicy *PassiveHealthChecks - - hostURL *url.URL -} - -func (u Upstream) Available() bool { - return u.Healthy() && !u.Full() -} - -func (u Upstream) Healthy() bool { - healthy := !u.Host.Unhealthy() - if healthy && u.healthCheckPolicy != nil { - healthy = u.Host.Fails() < u.healthCheckPolicy.MaxFails - } - return healthy -} - -func (u Upstream) Full() bool { - return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests -} - -func (u Upstream) URL() *url.URL { - return u.hostURL -} - -var hosts = caddy.NewUsagePool() - -// TODO: ... -type UpstreamProvider interface { -} - // TODO: see if we can use this // var bufPool = sync.Pool{ // New: func() interface{} { @@ -863,7 +534,7 @@ type UpstreamProvider interface { // Interface guards var ( - _ caddyhttp.MiddlewareHandler = (*Handler)(nil) _ caddy.Provisioner = (*Handler)(nil) _ caddy.CleanerUpper = (*Handler)(nil) + _ caddyhttp.MiddlewareHandler = (*Handler)(nil) ) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go index e0518c9..9680583 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -52,7 +52,7 @@ func (RandomSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (r RandomSelection) Select(pool HostPool, request *http.Request) *Upstream { +func (r RandomSelection) Select(pool UpstreamPool, request *http.Request) *Upstream { // use reservoir sampling because the number of available // hosts isn't known: https://en.wikipedia.org/wiki/Reservoir_sampling var randomHost *Upstream @@ -87,6 +87,7 @@ func (RandomChoiceSelection) CaddyModule() caddy.ModuleInfo { } } +// Provision sets up r. func (r *RandomChoiceSelection) Provision(ctx caddy.Context) error { if r.Choose == 0 { r.Choose = 2 @@ -94,6 +95,7 @@ func (r *RandomChoiceSelection) Provision(ctx caddy.Context) error { return nil } +// Validate ensures that r's configuration is valid. func (r RandomChoiceSelection) Validate() error { if r.Choose < 2 { return fmt.Errorf("choose must be at least 2") @@ -102,7 +104,7 @@ func (r RandomChoiceSelection) Validate() error { } // Select returns an available host, if any. -func (r RandomChoiceSelection) Select(pool HostPool, _ *http.Request) *Upstream { +func (r RandomChoiceSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { k := r.Choose if k > len(pool) { k = len(pool) @@ -142,7 +144,7 @@ func (LeastConnSelection) CaddyModule() caddy.ModuleInfo { // Select selects the up host with the least number of connections in the // pool. If more than one host has the same least number of connections, // one of the hosts is chosen at random. -func (LeastConnSelection) Select(pool HostPool, _ *http.Request) *Upstream { +func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { var bestHost *Upstream var count int var leastReqs int @@ -185,7 +187,7 @@ func (RoundRobinSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (r *RoundRobinSelection) Select(pool HostPool, _ *http.Request) *Upstream { +func (r *RoundRobinSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { n := uint32(len(pool)) if n == 0 { return nil @@ -213,7 +215,7 @@ func (FirstSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (FirstSelection) Select(pool HostPool, _ *http.Request) *Upstream { +func (FirstSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { for _, host := range pool { if host.Available() { return host @@ -235,7 +237,7 @@ func (IPHashSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (IPHashSelection) Select(pool HostPool, req *http.Request) *Upstream { +func (IPHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream { clientIP, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { clientIP = req.RemoteAddr @@ -256,7 +258,7 @@ func (URIHashSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (URIHashSelection) Select(pool HostPool, req *http.Request) *Upstream { +func (URIHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream { return hostByHashing(pool, req.RequestURI) } @@ -275,7 +277,7 @@ func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo { } // Select returns an available host, if any. -func (s HeaderHashSelection) Select(pool HostPool, req *http.Request) *Upstream { +func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream { if s.Field == "" { return nil } diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go new file mode 100644 index 0000000..a3b711b --- /dev/null +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -0,0 +1,223 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Most of the code in this file was initially borrowed from the Go +// standard library, which has this copyright notice: +// Copyright 2011 The Go Authors. + +package reverseproxy + +import ( + "context" + "io" + "net/http" + "sync" + "time" +) + +func (h Handler) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if reqUpType != resUpType { + // TODO: figure out our own error handling + // p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + + copyHeader(res.Header, rw.Header()) + + hj, ok := rw.(http.Hijacker) + if !ok { + // p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + // p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + defer backConn.Close() + conn, brw, err := hj.Hijack() + if err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + return + } + defer conn.Close() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + // p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc + return +} + +// flushInterval returns the p.FlushInterval value, conditionally +// overriding its value for a specific request/response. +func (h Handler) flushInterval(req *http.Request, res *http.Response) time.Duration { + resCT := res.Header.Get("Content-Type") + + // For Server-Sent Events responses, flush immediately. + // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream + if resCT == "text/event-stream" { + return -1 // negative means immediately + } + + // TODO: more specific cases? e.g. res.ContentLength == -1? (this TODO is from the std lib) + return time.Duration(h.FlushInterval) +} + +func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { + if flushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: flushInterval, + } + defer mlw.stop() + dst = mlw + } + } + + // TODO: Figure out how we want to do this... using custom buffer pool type seems unnecessary + // or maybe it is, depending on how we want to handle errors, + // see: https://github.com/golang/go/issues/21814 + // buf := bufPool.Get().(*bytes.Buffer) + // buf.Reset() + // defer bufPool.Put(buf) + // _, err := io.CopyBuffer(dst, src, ) + var buf []byte + // if h.BufferPool != nil { + // buf = h.BufferPool.Get() + // defer h.BufferPool.Put(buf) + // } + _, err := h.copyBuffer(dst, src, buf) + return err +} + +// copyBuffer returns any write errors or non-EOF read errors, and the amount +// of bytes written. +func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF && rerr != context.Canceled { + // TODO: this could be useful to know (indeed, it revealed an error in our + // fastcgi PoC earlier; but it's this single error report here that necessitates + // a function separate from io.CopyBuffer, since io.CopyBuffer does not distinguish + // between read or write errors; in a reverse proxy situation, write errors are not + // something we need to report to the client, but read errors are a problem on our + // end for sure. so we need to decide what we want.) + // p.logf("copyBuffer: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + if rerr == io.EOF { + rerr = nil + } + return written, rerr + } + } +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration // non-zero; negative means to flush immediately + + mu sync.Mutex // protects t, flushPending, and dst.Flush + t *time.Timer + flushPending bool +} + +func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + n, err = m.dst.Write(p) + if m.latency < 0 { + m.dst.Flush() + return + } + if m.flushPending { + return + } + if m.t == nil { + m.t = time.AfterFunc(m.latency, m.delayedFlush) + } else { + m.t.Reset(m.latency) + } + m.flushPending = true + return +} + +func (m *maxLatencyWriter) delayedFlush() { + m.mu.Lock() + defer m.mu.Unlock() + if !m.flushPending { // if stop was called but AfterFunc already started this goroutine + return + } + m.dst.Flush() + m.flushPending = false +} + +func (m *maxLatencyWriter) stop() { + m.mu.Lock() + defer m.mu.Unlock() + m.flushPending = false + if m.t != nil { + m.t.Stop() + } +} + +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} -- cgit v1.2.3 From acb8f0e0c26acd95cbee8981469b4ac62535d164 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Tue, 3 Sep 2019 19:06:54 -0600 Subject: Integrate circuit breaker modules with reverse proxy --- modules/caddyhttp/reverseproxy/healthchecks.go | 10 +++++++++- modules/caddyhttp/reverseproxy/hosts.go | 10 +++++++++- modules/caddyhttp/reverseproxy/reverseproxy.go | 18 ++++++++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go index 0b46d04..673f7c4 100644 --- a/modules/caddyhttp/reverseproxy/healthchecks.go +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -64,6 +64,14 @@ type PassiveHealthChecks struct { UnhealthyLatency caddy.Duration `json:"unhealthy_latency,omitempty"` } +// CircuitBreaker is a type that can act as an early-warning +// system for the health checker when backends are getting +// overloaded. +type CircuitBreaker interface { + OK() bool + RecordMetric(statusCode int, latency time.Duration) +} + // activeHealthChecker runs active health checks on a // regular basis and blocks until // h.HealthChecks.Active.stopChan is closed. @@ -202,7 +210,7 @@ func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error { // remembers 1 failure for upstream for the configured // duration. If passive health checks are disabled or // failure expiry is 0, this is a no-op. -func (h Handler) countFailure(upstream *Upstream) { +func (h *Handler) countFailure(upstream *Upstream) { // only count failures if passive health checking is enabled // and if failures are configured have a non-zero expiry if h.HealthChecks == nil || h.HealthChecks.Passive == nil { diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index 5100936..b40e614 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -69,21 +69,29 @@ type Upstream struct { healthCheckPolicy *PassiveHealthChecks hostURL *url.URL + cb CircuitBreaker } // Available returns true if the remote host -// is available to receive requests. +// is available to receive requests. This is +// the method that should be used by selection +// policies, etc. to determine if a backend +// should be able to be sent a request. func (u *Upstream) Available() bool { return u.Healthy() && !u.Full() } // Healthy returns true if the remote host // is currently known to be healthy or "up". +// It consults the circuit breaker, if any. func (u *Upstream) Healthy() bool { healthy := !u.Host.Unhealthy() if healthy && u.healthCheckPolicy != nil { healthy = u.Host.Fails() < u.healthCheckPolicy.MaxFails } + if healthy && u.cb != nil { + healthy = u.cb.OK() + } return healthy } diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index ca54741..16d7f7a 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -37,12 +37,14 @@ func init() { // Handler implements a highly configurable and production-ready reverse proxy. type Handler struct { TransportRaw json.RawMessage `json:"transport,omitempty"` + CBRaw json.RawMessage `json:"circuit_breaker,omitempty"` LoadBalancing *LoadBalancing `json:"load_balancing,omitempty"` HealthChecks *HealthChecks `json:"health_checks,omitempty"` Upstreams UpstreamPool `json:"upstreams,omitempty"` FlushInterval caddy.Duration `json:"flush_interval,omitempty"` Transport http.RoundTripper `json:"-"` + CB CircuitBreaker `json:"-"` } // CaddyModule returns the Caddy module information. @@ -55,6 +57,7 @@ func (Handler) CaddyModule() caddy.ModuleInfo { // Provision ensures that h is set up properly before use. func (h *Handler) Provision(ctx caddy.Context) error { + // start by loading modules if h.TransportRaw != nil { val, err := ctx.LoadModuleInline("protocol", "http.handlers.reverse_proxy.transport", h.TransportRaw) if err != nil { @@ -73,6 +76,14 @@ func (h *Handler) Provision(ctx caddy.Context) error { h.LoadBalancing.SelectionPolicy = val.(Selector) h.LoadBalancing.SelectionPolicyRaw = nil // allow GC to deallocate - TODO: Does this help? } + if h.CBRaw != nil { + val, err := ctx.LoadModuleInline("type", "http.handlers.reverse_proxy.circuit_breakers", h.CBRaw) + if err != nil { + return fmt.Errorf("loading circuit breaker module: %s", err) + } + h.CB = val.(CircuitBreaker) + h.CBRaw = nil // allow GC to deallocate - TODO: Does this help? + } if h.Transport == nil { h.Transport = defaultTransport @@ -123,6 +134,8 @@ func (h *Handler) Provision(ctx caddy.Context) error { } for _, upstream := range h.Upstreams { + upstream.cb = h.CB + // url parser requires a scheme if !strings.Contains(upstream.Address, "://") { upstream.Address = "http://" + upstream.Address @@ -307,6 +320,11 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, upstre return err } + // update circuit breaker on current conditions + if upstream.cb != nil { + upstream.cb.RecordMetric(res.StatusCode, latency) + } + // perform passive health checks (if enabled) if h.HealthChecks != nil && h.HealthChecks.Passive != nil { // strike if the status code matches one that is "bad" -- cgit v1.2.3 From a60d54dbfd93f74187b4051f1522c42d34480503 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Tue, 3 Sep 2019 19:10:09 -0600 Subject: reverse_proxy: Ignore context.Canceled errors These happen when downstream clients cancel the request, but that's not our problem nor a failure in our end --- modules/caddyhttp/reverseproxy/reverseproxy.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 16d7f7a..7bf9a2f 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -224,7 +224,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht // proxy the request to that upstream proxyErr = h.reverseProxy(w, r, upstream) - if proxyErr == nil { + if proxyErr == nil || proxyErr == context.Canceled { + // context.Canceled happens when the downstream client + // cancels the request; we don't have to worry about that return nil } -- cgit v1.2.3 From 0830fbad0347ead1dbea60e664556b263e44653f Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Thu, 5 Sep 2019 13:14:39 -0600 Subject: Reconcile upstream dial addresses and request host/URL information My goodness that was complicated Blessed be request.Context Sort of --- modules/caddyhttp/caddyhttp.go | 12 +- modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go | 102 ++++++----------- modules/caddyhttp/reverseproxy/healthchecks.go | 68 +++++++---- modules/caddyhttp/reverseproxy/hosts.go | 38 +++++-- modules/caddyhttp/reverseproxy/httptransport.go | 28 ++--- modules/caddyhttp/reverseproxy/reverseproxy.go | 130 ++++++++++++++-------- modules/caddyhttp/server.go | 2 +- 7 files changed, 212 insertions(+), 168 deletions(-) (limited to 'modules') diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go index 300b5fd..174e316 100644 --- a/modules/caddyhttp/caddyhttp.go +++ b/modules/caddyhttp/caddyhttp.go @@ -108,7 +108,7 @@ func (app *App) Validate() error { lnAddrs := make(map[string]string) for srvName, srv := range app.Servers { for _, addr := range srv.Listen { - netw, expanded, err := caddy.ParseListenAddr(addr) + netw, expanded, err := caddy.ParseNetworkAddress(addr) if err != nil { return fmt.Errorf("invalid listener address '%s': %v", addr, err) } @@ -149,7 +149,7 @@ func (app *App) Start() error { } for _, lnAddr := range srv.Listen { - network, addrs, err := caddy.ParseListenAddr(lnAddr) + network, addrs, err := caddy.ParseNetworkAddress(lnAddr) if err != nil { return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err) } @@ -309,7 +309,7 @@ func (app *App) automaticHTTPS() error { // create HTTP->HTTPS redirects for _, addr := range srv.Listen { - netw, host, port, err := caddy.SplitListenAddr(addr) + netw, host, port, err := caddy.SplitNetworkAddress(addr) if err != nil { return fmt.Errorf("%s: invalid listener address: %v", srvName, addr) } @@ -318,7 +318,7 @@ func (app *App) automaticHTTPS() error { if httpPort == 0 { httpPort = DefaultHTTPPort } - httpRedirLnAddr := caddy.JoinListenAddr(netw, host, strconv.Itoa(httpPort)) + httpRedirLnAddr := caddy.JoinNetworkAddress(netw, host, strconv.Itoa(httpPort)) lnAddrMap[httpRedirLnAddr] = struct{}{} if parts := strings.SplitN(port, "-", 2); len(parts) == 2 { @@ -361,7 +361,7 @@ func (app *App) automaticHTTPS() error { var lnAddrs []string mapLoop: for addr := range lnAddrMap { - netw, addrs, err := caddy.ParseListenAddr(addr) + netw, addrs, err := caddy.ParseNetworkAddress(addr) if err != nil { continue } @@ -386,7 +386,7 @@ func (app *App) automaticHTTPS() error { func (app *App) listenerTaken(network, address string) bool { for _, srv := range app.Servers { for _, addr := range srv.Listen { - netw, addrs, err := caddy.ParseListenAddr(addr) + netw, addrs, err := caddy.ParseNetworkAddress(addr) if err != nil || netw != network { continue } diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go index 32f094b..35fef5f 100644 --- a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go +++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy" "github.com/caddyserver/caddy/v2/modules/caddytls" "github.com/caddyserver/caddy/v2" @@ -34,6 +35,7 @@ func init() { caddy.RegisterModule(Transport{}) } +// Transport facilitates FastCGI communication. type Transport struct { ////////////////////////////// // TODO: taken from v1 Handler type @@ -57,32 +59,32 @@ type Transport struct { // Use this directory as the fastcgi root directory. Defaults to the root // directory of the parent virtual host. - Root string + Root string `json:"root,omitempty"` // The path in the URL will be split into two, with the first piece ending // with the value of SplitPath. The first piece will be assumed as the // actual resource (CGI script) name, and the second piece will be set to // PATH_INFO for the CGI script to use. - SplitPath string + SplitPath string `json:"split_path,omitempty"` // If the URL ends with '/' (which indicates a directory), these index // files will be tried instead. - IndexFiles []string + // IndexFiles []string // Environment Variables - EnvVars [][2]string + EnvVars [][2]string `json:"env,omitempty"` // Ignored paths - IgnoredSubPaths []string + // IgnoredSubPaths []string // The duration used to set a deadline when connecting to an upstream. - DialTimeout time.Duration + DialTimeout caddy.Duration `json:"dial_timeout,omitempty"` // The duration used to set a deadline when reading from the FastCGI server. - ReadTimeout time.Duration + ReadTimeout caddy.Duration `json:"read_timeout,omitempty"` // The duration used to set a deadline when sending to the FastCGI server. - WriteTimeout time.Duration + WriteTimeout caddy.Duration `json:"write_timeout,omitempty"` } // CaddyModule returns the Caddy module information. @@ -93,102 +95,62 @@ func (Transport) CaddyModule() caddy.ModuleInfo { } } +// RoundTrip implements http.RoundTripper. func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) { - // Create environment for CGI script env, err := t.buildEnv(r) if err != nil { return nil, fmt.Errorf("building environment: %v", err) } - // TODO: - // Connect to FastCGI gateway - // address, err := f.Address() - // if err != nil { - // return http.StatusBadGateway, err - // } - // network, address := parseAddress(address) - network, address := "tcp", r.URL.Host // TODO: - + // TODO: doesn't dialer have a Timeout field? ctx := context.Background() if t.DialTimeout > 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, t.DialTimeout) + ctx, cancel = context.WithTimeout(ctx, time.Duration(t.DialTimeout)) defer cancel() } + // extract dial information from request (this + // should embedded by the reverse proxy) + network, address := "tcp", r.URL.Host + if dialInfoVal := ctx.Value(reverseproxy.DialInfoCtxKey); dialInfoVal != nil { + dialInfo := dialInfoVal.(reverseproxy.DialInfo) + network = dialInfo.Network + address = dialInfo.Address + } + fcgiBackend, err := DialContext(ctx, network, address) if err != nil { return nil, fmt.Errorf("dialing backend: %v", err) } - // fcgiBackend is closed when response body is closed (see clientCloser) + // fcgiBackend gets closed when response body is closed (see clientCloser) // read/write timeouts - if err := fcgiBackend.SetReadTimeout(t.ReadTimeout); err != nil { + if err := fcgiBackend.SetReadTimeout(time.Duration(t.ReadTimeout)); err != nil { return nil, fmt.Errorf("setting read timeout: %v", err) } - if err := fcgiBackend.SetWriteTimeout(t.WriteTimeout); err != nil { + if err := fcgiBackend.SetWriteTimeout(time.Duration(t.WriteTimeout)); err != nil { return nil, fmt.Errorf("setting write timeout: %v", err) } - var resp *http.Response - - var contentLength int64 - // if ContentLength is already set - if r.ContentLength > 0 { - contentLength = r.ContentLength - } else { + contentLength := r.ContentLength + if contentLength == 0 { contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64) } + + var resp *http.Response switch r.Method { - case "HEAD": + case http.MethodHead: resp, err = fcgiBackend.Head(env) - case "GET": + case http.MethodGet: resp, err = fcgiBackend.Get(env, r.Body, contentLength) - case "OPTIONS": + case http.MethodOptions: resp, err = fcgiBackend.Options(env) default: resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength) } - // TODO: return resp, err - - // Stuff brought over from v1 that might not be necessary here: - - // if resp != nil && resp.Body != nil { - // defer resp.Body.Close() - // } - - // if err != nil { - // if err, ok := err.(net.Error); ok && err.Timeout() { - // return http.StatusGatewayTimeout, err - // } else if err != io.EOF { - // return http.StatusBadGateway, err - // } - // } - - // // Write response header - // writeHeader(w, resp) - - // // Write the response body - // _, err = io.Copy(w, resp.Body) - // if err != nil { - // return http.StatusBadGateway, err - // } - - // // Log any stderr output from upstream - // if fcgiBackend.stderr.Len() != 0 { - // // Remove trailing newline, error logger already does this. - // err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n")) - // } - - // // Normally we would return the status code if it is an error status (>= 400), - // // however, upstream FastCGI apps don't know about our contract and have - // // probably already written an error page. So we just return 0, indicating - // // that the response body is already written. However, we do return any - // // error value so it can be logged. - // // Note that the proxy middleware works the same way, returning status=0. - // return 0, err } // buildEnv returns a set of CGI environment variables for the request. diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go index 673f7c4..abe0f9c 100644 --- a/modules/caddyhttp/reverseproxy/healthchecks.go +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -15,6 +15,7 @@ package reverseproxy import ( + "context" "fmt" "io" "io/ioutil" @@ -93,15 +94,31 @@ func (h *Handler) activeHealthChecker() { // health checks for all hosts in the global repository. func (h *Handler) doActiveHealthChecksForAllHosts() { hosts.Range(func(key, value interface{}) bool { - addr := key.(string) + networkAddr := key.(string) host := value.(Host) - go func(addr string, host Host) { - err := h.doActiveHealthCheck(addr, host) + go func(networkAddr string, host Host) { + network, addrs, err := caddy.ParseNetworkAddress(networkAddr) if err != nil { - log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", addr, err) + log.Printf("[ERROR] reverse_proxy: active health check for host %s: bad network address: %v", networkAddr, err) + return } - }(addr, host) + if len(addrs) != 1 { + log.Printf("[ERROR] reverse_proxy: active health check for host %s: multiple addresses (upstream must map to only one address)", networkAddr) + return + } + hostAddr := addrs[0] + if network == "unix" || network == "unixgram" || network == "unixpacket" { + // this will be used as the Host portion of a http.Request URL, and + // paths to socket files would produce an error when creating URL, + // so use a fake Host value instead + hostAddr = network + } + err = h.doActiveHealthCheck(DialInfo{network, addrs[0]}, hostAddr, host) + if err != nil { + log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", networkAddr, err) + } + }(networkAddr, host) // continue to iterate all hosts return true @@ -115,26 +132,39 @@ func (h *Handler) doActiveHealthChecksForAllHosts() { // according to whether it passes the health check. An error is // returned only if the health check fails to occur or if marking // the host's health status fails. -func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error { - // create the URL for the health check - u, err := url.Parse(hostAddr) - if err != nil { - return err +func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host Host) error { + // create the URL for the request that acts as a health check + scheme := "http" + if ht, ok := h.Transport.(*http.Transport); ok && ht.TLSClientConfig != nil { + // this is kind of a hacky way to know if we should use HTTPS, but whatever + scheme = "https" } - if h.HealthChecks.Active.Path != "" { - u.Path = h.HealthChecks.Active.Path + u := &url.URL{ + Scheme: scheme, + Host: hostAddr, + Path: h.HealthChecks.Active.Path, } + + // adjust the port, if configured to be different if h.HealthChecks.Active.Port != 0 { portStr := strconv.Itoa(h.HealthChecks.Active.Port) - u.Host = net.JoinHostPort(u.Hostname(), portStr) + host, _, err := net.SplitHostPort(hostAddr) + if err != nil { + host = hostAddr + } + u.Host = net.JoinHostPort(host, portStr) } - req, err := http.NewRequest(http.MethodGet, u.String(), nil) + // attach dialing information to this request + ctx := context.Background() + ctx = context.WithValue(ctx, caddy.ReplacerCtxKey, caddy.NewReplacer()) + ctx = context.WithValue(ctx, DialInfoCtxKey, dialInfo) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { - return err + return fmt.Errorf("making request: %v", err) } - // do the request, careful to tame the response body + // do the request, being careful to tame the response body resp, err := h.HealthChecks.Active.httpClient.Do(req) if err != nil { log.Printf("[INFO] reverse_proxy: active health check: %s is down (HTTP request failed: %v)", hostAddr, err) @@ -149,7 +179,7 @@ func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error { body = io.LimitReader(body, h.HealthChecks.Active.MaxSize) } defer func() { - // drain any remaining body so connection can be re-used + // drain any remaining body so connection could be re-used io.Copy(ioutil.Discard, body) resp.Body.Close() }() @@ -225,7 +255,7 @@ func (h *Handler) countFailure(upstream *Upstream) { err := upstream.Host.CountFail(1) if err != nil { log.Printf("[ERROR] proxy: upstream %s: counting failure: %v", - upstream.hostURL, err) + upstream.dialInfo, err) } // forget it later @@ -234,7 +264,7 @@ func (h *Handler) countFailure(upstream *Upstream) { err := host.CountFail(-1) if err != nil { log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v", - upstream.hostURL, err) + upstream.dialInfo, err) } }(upstream.Host, failDuration) } diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index b40e614..ad27625 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -16,7 +16,6 @@ package reverseproxy import ( "fmt" - "net/url" "sync/atomic" "github.com/caddyserver/caddy/v2" @@ -59,7 +58,7 @@ type UpstreamPool []*Upstream type Upstream struct { Host `json:"-"` - Address string `json:"address,omitempty"` + Dial string `json:"dial,omitempty"` MaxRequests int `json:"max_requests,omitempty"` // TODO: This could be really useful, to bind requests @@ -68,8 +67,8 @@ type Upstream struct { // IPAffinity string healthCheckPolicy *PassiveHealthChecks - hostURL *url.URL cb CircuitBreaker + dialInfo DialInfo } // Available returns true if the remote host @@ -101,11 +100,6 @@ func (u *Upstream) Full() bool { return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests } -// URL returns the upstream host's endpoint URL. -func (u *Upstream) URL() *url.URL { - return u.hostURL -} - // upstreamHost is the basic, in-memory representation // of the state of a remote host. It implements the // Host interface. @@ -162,6 +156,34 @@ func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) { return swapped, nil } +// DialInfo contains information needed to dial a +// connection to an upstream host. This information +// may be different than that which is represented +// in a URL (for example, unix sockets don't have +// a host that can be represented in a URL, but +// they certainly have a network name and address). +type DialInfo struct { + // The network to use. This should be one of the + // values that is accepted by net.Dial: + // https://golang.org/pkg/net/#Dial + Network string + + // The address to dial. Follows the same + // semantics and rules as net.Dial. + Address string +} + +// String returns the Caddy network address form +// by joining the network and address with a +// forward slash. +func (di DialInfo) String() string { + return di.Network + "/" + di.Address +} + +// DialInfoCtxKey is used to store a DialInfo +// in a context.Context. +const DialInfoCtxKey = caddy.CtxKey("dial_info") + // hosts is the global repository for hosts that are // currently in use by active configuration(s). This // allows the state of remote hosts to be preserved diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index d9dc457..6c1d9c8 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -15,6 +15,7 @@ package reverseproxy import ( + "context" "crypto/tls" "crypto/x509" "encoding/base64" @@ -63,14 +64,23 @@ func (HTTPTransport) CaddyModule() caddy.ModuleInfo { // Provision sets up h.RoundTripper with a http.Transport // that is ready to use. -func (h *HTTPTransport) Provision(ctx caddy.Context) error { +func (h *HTTPTransport) Provision(_ caddy.Context) error { dialer := &net.Dialer{ Timeout: time.Duration(h.DialTimeout), FallbackDelay: time.Duration(h.FallbackDelay), // TODO: Resolver } + rt := &http.Transport{ - DialContext: dialer.DialContext, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + // the proper dialing information should be embedded into the request's context + if dialInfoVal := ctx.Value(DialInfoCtxKey); dialInfoVal != nil { + dialInfo := dialInfoVal.(DialInfo) + network = dialInfo.Network + address = dialInfo.Address + } + return dialer.DialContext(ctx, network, address) + }, MaxConnsPerHost: h.MaxConnsPerHost, ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout), ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout), @@ -91,7 +101,6 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error { if h.KeepAlive != nil { dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval) - if enabled := h.KeepAlive.Enabled; enabled != nil { rt.DisableKeepAlives = !*enabled } @@ -191,16 +200,3 @@ type KeepAlive struct { MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"` IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle } - -var ( - defaultDialer = net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, - } - - defaultTransport = &http.Transport{ - DialContext: defaultDialer.DialContext, - TLSHandshakeTimeout: 5 * time.Second, - IdleConnTimeout: 2 * time.Minute, - } -) diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 7bf9a2f..5a37613 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -20,7 +20,6 @@ import ( "fmt" "net" "net/http" - "net/url" "regexp" "strings" "time" @@ -86,7 +85,18 @@ func (h *Handler) Provision(ctx caddy.Context) error { } if h.Transport == nil { - h.Transport = defaultTransport + t := &HTTPTransport{ + KeepAlive: &KeepAlive{ + ProbeInterval: caddy.Duration(30 * time.Second), + IdleConnTimeout: caddy.Duration(2 * time.Minute), + }, + DialTimeout: caddy.Duration(10 * time.Second), + } + err := t.Provision(ctx) + if err != nil { + return fmt.Errorf("provisioning default transport: %v", err) + } + h.Transport = t } if h.LoadBalancing == nil { @@ -133,51 +143,65 @@ func (h *Handler) Provision(ctx caddy.Context) error { go h.activeHealthChecker() } + var allUpstreams []*Upstream for _, upstream := range h.Upstreams { - upstream.cb = h.CB - - // url parser requires a scheme - if !strings.Contains(upstream.Address, "://") { - upstream.Address = "http://" + upstream.Address - } - u, err := url.Parse(upstream.Address) + // upstreams are allowed to map to only a single host, + // but an upstream's address may semantically represent + // multiple addresses, so make sure to handle each + // one in turn based on this one upstream config + network, addresses, err := caddy.ParseNetworkAddress(upstream.Dial) if err != nil { - return fmt.Errorf("invalid upstream address %s: %v", upstream.Address, err) - } - upstream.hostURL = u - - // if host already exists from a current config, - // use that instead; otherwise, add it - // TODO: make hosts modular, so that their state can be distributed in enterprise for example - // TODO: If distributed, the pool should be stored in storage... - var host Host = new(upstreamHost) - activeHost, loaded := hosts.LoadOrStore(u.String(), host) - if loaded { - host = activeHost.(Host) - } - upstream.Host = host - - // if the passive health checker has a non-zero "unhealthy - // request count" but the upstream has no MaxRequests set - // (they are the same thing, but one is a default value for - // for upstreams with a zero MaxRequests), copy the default - // value into this upstream, since the value in the upstream - // is what is used during availability checks - if h.HealthChecks != nil && - h.HealthChecks.Passive != nil && - h.HealthChecks.Passive.UnhealthyRequestCount > 0 && - upstream.MaxRequests == 0 { - upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount + return fmt.Errorf("parsing dial address: %v", err) } - if h.HealthChecks != nil { + for _, addr := range addresses { + // make a new upstream based on the original + // that has a singular dial address + upstreamCopy := *upstream + upstreamCopy.dialInfo = DialInfo{network, addr} + upstreamCopy.Dial = upstreamCopy.dialInfo.String() + upstreamCopy.cb = h.CB + + // if host already exists from a current config, + // use that instead; otherwise, add it + // TODO: make hosts modular, so that their state can be distributed in enterprise for example + // TODO: If distributed, the pool should be stored in storage... + var host Host = new(upstreamHost) + activeHost, loaded := hosts.LoadOrStore(upstreamCopy.Dial, host) + if loaded { + host = activeHost.(Host) + } + upstreamCopy.Host = host + + // if the passive health checker has a non-zero "unhealthy + // request count" but the upstream has no MaxRequests set + // (they are the same thing, but one is a default value for + // for upstreams with a zero MaxRequests), copy the default + // value into this upstream, since the value in the upstream + // (MaxRequests) is what is used during availability checks + if h.HealthChecks != nil && + h.HealthChecks.Passive != nil && + h.HealthChecks.Passive.UnhealthyRequestCount > 0 && + upstreamCopy.MaxRequests == 0 { + upstreamCopy.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount + } + // upstreams need independent access to the passive - // health check policy so they can, you know, passively - // do health checks - upstream.healthCheckPolicy = h.HealthChecks.Passive + // health check policy because they run outside of the + // scope of a request handler + if h.HealthChecks != nil { + upstreamCopy.healthCheckPolicy = h.HealthChecks.Passive + } + + allUpstreams = append(allUpstreams, &upstreamCopy) } } + // replace the unmarshaled upstreams (possible 1:many + // address mapping) with our list, which is mapped 1:1, + // thus may have expanded the original list + h.Upstreams = allUpstreams + return nil } @@ -192,7 +216,7 @@ func (h *Handler) Cleanup() error { // remove hosts from our config from the pool for _, upstream := range h.Upstreams { - hosts.Delete(upstream.hostURL.String()) + hosts.Delete(upstream.dialInfo.String()) } return nil @@ -222,6 +246,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht continue } + // attach to the request information about how to dial the upstream; + // this is necessary because the information cannot be sufficiently + // or satisfactorily represented in a URL + ctx := context.WithValue(r.Context(), DialInfoCtxKey, upstream.dialInfo) + r = r.WithContext(ctx) + // proxy the request to that upstream proxyErr = h.reverseProxy(w, r, upstream) if proxyErr == nil || proxyErr == context.Canceled { @@ -249,6 +279,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht // This assumes that no mutations of the request are performed // by h during or after proxying. func (h Handler) prepareRequest(req *http.Request) error { + // as a special (but very common) case, if the transport + // is HTTP, then ensure the request has the proper scheme + // because incoming requests by default are lacking it + if req.URL.Scheme == "" { + req.URL.Scheme = "http" + if ht, ok := h.Transport.(*HTTPTransport); ok && ht.TLS != nil { + req.URL.Scheme = "https" + } + } + if req.ContentLength == 0 { req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries } @@ -433,14 +473,8 @@ func (h Handler) tryAgain(start time.Time, proxyErr error) bool { // directRequest modifies only req.URL so that it points to the // given upstream host. It must modify ONLY the request URL. func (h Handler) directRequest(req *http.Request, upstream *Upstream) { - target := upstream.hostURL - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) // TODO: This might be a bug (if any part of the path was augmented from a previously-tried upstream; need to start from clean original path of request, same for query string!) - if target.RawQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = target.RawQuery + req.URL.RawQuery - } else { - req.URL.RawQuery = target.RawQuery + "&" + req.URL.RawQuery + if req.URL.Host == "" { + req.URL.Host = upstream.dialInfo.Address } } diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index f820f71..248e5f2 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -168,7 +168,7 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next // listeners in s that use a port which is not otherPort. func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool { for _, lnAddr := range s.Listen { - _, addrs, err := caddy.ParseListenAddr(lnAddr) + _, addrs, err := caddy.ParseNetworkAddress(lnAddr) if err == nil { for _, a := range addrs { _, port, err := net.SplitHostPort(a) -- cgit v1.2.3 From 80b54f3b9d5e207316fb9e8f83dd1e90659b25d7 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Thu, 5 Sep 2019 13:36:42 -0600 Subject: Add original URI to request context; implement into fastcgi env --- modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go | 77 +++++++++-------------- modules/caddyhttp/server.go | 17 +++++ 2 files changed, 46 insertions(+), 48 deletions(-) (limited to 'modules') diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go index 35fef5f..090de25 100644 --- a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go +++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go @@ -19,12 +19,14 @@ import ( "crypto/tls" "fmt" "net/http" + "net/url" "path" "path/filepath" "strconv" "strings" "time" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy" "github.com/caddyserver/caddy/v2/modules/caddytls" @@ -37,25 +39,11 @@ func init() { // Transport facilitates FastCGI communication. type Transport struct { - ////////////////////////////// - // TODO: taken from v1 Handler type - - SoftwareName string - SoftwareVersion string - ServerName string - ServerPort string - - ////////////////////////// - // TODO: taken from v1 Rule type - - // The base path to match. Required. - // Path string - - // upstream load balancer - // balancer - - // Always process files with this extension with fastcgi. - // Ext string + // TODO: Populate these + softwareName string + softwareVersion string + serverName string + serverPort string // Use this directory as the fastcgi root directory. Defaults to the root // directory of the parent virtual host. @@ -67,16 +55,9 @@ type Transport struct { // PATH_INFO for the CGI script to use. SplitPath string `json:"split_path,omitempty"` - // If the URL ends with '/' (which indicates a directory), these index - // files will be tried instead. - // IndexFiles []string - // Environment Variables EnvVars [][2]string `json:"env,omitempty"` - // Ignored paths - // IgnoredSubPaths []string - // The duration used to set a deadline when connecting to an upstream. DialTimeout caddy.Duration `json:"dial_timeout,omitempty"` @@ -170,7 +151,6 @@ func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { ip = strings.Replace(ip, "[", "", 1) ip = strings.Replace(ip, "]", "", 1) - // TODO: respect index files? or leave that to matcher/rewrite (I prefer that)? fpath := r.URL.Path // Split path in preparation for env variables. @@ -194,16 +174,17 @@ func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { pathPrefix, _ := r.Context().Value(caddy.CtxKey("path_prefix")).(string) scriptName = path.Join(pathPrefix, scriptName) - // TODO: Disabled for now - // // Get the request URI from context. The context stores the original URI in case - // // it was changed by a middleware such as rewrite. By default, we pass the - // // original URI in as the value of REQUEST_URI (the user can overwrite this - // // if desired). Most PHP apps seem to want the original URI. Besides, this is - // // how nginx defaults: http://stackoverflow.com/a/12485156/1048862 - // reqURL, _ := r.Context().Value(httpserver.OriginalURLCtxKey).(url.URL) - - // // Retrieve name of remote user that was set by some downstream middleware such as basicauth. - // remoteUser, _ := r.Context().Value(httpserver.RemoteUserCtxKey).(string) + // Get the request URL from context. The context stores the original URL in case + // it was changed by a middleware such as rewrite. By default, we pass the + // original URI in as the value of REQUEST_URI (the user can overwrite this + // if desired). Most PHP apps seem to want the original URI. Besides, this is + // how nginx defaults: http://stackoverflow.com/a/12485156/1048862 + reqURL, ok := r.Context().Value(caddyhttp.OriginalURLCtxKey).(url.URL) + if !ok { + // some requests, like active health checks, don't add this to + // the request context, so we can just use the current URL + reqURL = *r.URL + } requestScheme := "http" if r.TLS != nil { @@ -224,19 +205,19 @@ func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { "REMOTE_HOST": ip, // For speed, remote host lookups disabled "REMOTE_PORT": port, "REMOTE_IDENT": "", // Not used - // "REMOTE_USER": remoteUser, // TODO: - "REQUEST_METHOD": r.Method, - "REQUEST_SCHEME": requestScheme, - "SERVER_NAME": t.ServerName, - "SERVER_PORT": t.ServerPort, - "SERVER_PROTOCOL": r.Proto, - "SERVER_SOFTWARE": t.SoftwareName + "/" + t.SoftwareVersion, + "REMOTE_USER": "", // TODO: once there are authentication handlers, populate this + "REQUEST_METHOD": r.Method, + "REQUEST_SCHEME": requestScheme, + "SERVER_NAME": t.ServerName, + "SERVER_PORT": t.ServerPort, + "SERVER_PROTOCOL": r.Proto, + "SERVER_SOFTWARE": t.SoftwareName + "/" + t.SoftwareVersion, // Other variables - // "DOCUMENT_ROOT": rule.Root, - "DOCUMENT_URI": docURI, - "HTTP_HOST": r.Host, // added here, since not always part of headers - // "REQUEST_URI": reqURL.RequestURI(), // TODO: + "DOCUMENT_ROOT": t.Root, + "DOCUMENT_URI": docURI, + "HTTP_HOST": r.Host, // added here, since not always part of headers + "REQUEST_URI": reqURL.RequestURI(), "SCRIPT_FILENAME": scriptFilename, "SCRIPT_NAME": scriptName, } diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index 248e5f2..04935e6 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -20,6 +20,7 @@ import ( "log" "net" "net/http" + "net/url" "strconv" "strings" @@ -58,6 +59,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), caddy.ReplacerCtxKey, repl) ctx = context.WithValue(ctx, ServerCtxKey, s) ctx = context.WithValue(ctx, VarCtxKey, make(map[string]interface{})) + ctx = context.WithValue(ctx, OriginalURLCtxKey, cloneURL(r.URL)) r = r.WithContext(ctx) // once the pointer to the request won't change @@ -228,6 +230,18 @@ type HTTPErrorConfig struct { Routes RouteList `json:"routes,omitempty"` } +// cloneURL makes a copy of r.URL and returns a +// new value that doesn't reference the original. +func cloneURL(u *url.URL) url.URL { + urlCopy := *u + if u.User != nil { + userInfo := new(url.Userinfo) + *userInfo = *u.User + urlCopy.User = userInfo + } + return urlCopy +} + // Context keys for HTTP request context values. const ( // For referencing the server instance @@ -235,4 +249,7 @@ const ( // For the request's variable table VarCtxKey caddy.CtxKey = "vars" + + // For the unmodified URL that originally came in with a request + OriginalURLCtxKey caddy.CtxKey = "original_url" ) -- cgit v1.2.3 From d2e46c2be0c72cf49d87a9c70400ff65046e5123 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Thu, 5 Sep 2019 13:42:20 -0600 Subject: fastcgi: Set default root path; add interface guards --- modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go | 30 +++++++++++++++++------ modules/caddyhttp/reverseproxy/httptransport.go | 6 +++++ 2 files changed, 29 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go index 090de25..9d724c1 100644 --- a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go +++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go @@ -76,6 +76,14 @@ func (Transport) CaddyModule() caddy.ModuleInfo { } } +// Provision sets up t. +func (t *Transport) Provision(_ caddy.Context) error { + if t.Root == "" { + t.Root = "{http.var.root}" + } + return nil +} + // RoundTrip implements http.RoundTripper. func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) { env, err := t.buildEnv(r) @@ -136,6 +144,8 @@ func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) { // buildEnv returns a set of CGI environment variables for the request. func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { + repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) + var env map[string]string // Separate remote IP and port; more lenient than net.SplitHostPort @@ -151,6 +161,7 @@ func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { ip = strings.Replace(ip, "[", "", 1) ip = strings.Replace(ip, "]", "", 1) + root := repl.ReplaceAll(t.Root, ".") fpath := r.URL.Path // Split path in preparation for env variables. @@ -167,7 +178,7 @@ func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { scriptName = strings.TrimSuffix(scriptName, pathInfo) // SCRIPT_FILENAME is the absolute path of SCRIPT_NAME - scriptFilename := filepath.Join(t.Root, scriptName) + scriptFilename := filepath.Join(root, scriptName) // Add vhost path prefix to scriptName. Otherwise, some PHP software will // have difficulty discovering its URL. @@ -208,13 +219,13 @@ func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { "REMOTE_USER": "", // TODO: once there are authentication handlers, populate this "REQUEST_METHOD": r.Method, "REQUEST_SCHEME": requestScheme, - "SERVER_NAME": t.ServerName, - "SERVER_PORT": t.ServerPort, + "SERVER_NAME": t.serverName, + "SERVER_PORT": t.serverPort, "SERVER_PROTOCOL": r.Proto, - "SERVER_SOFTWARE": t.SoftwareName + "/" + t.SoftwareVersion, + "SERVER_SOFTWARE": t.softwareName + "/" + t.softwareVersion, // Other variables - "DOCUMENT_ROOT": t.Root, + "DOCUMENT_ROOT": root, "DOCUMENT_URI": docURI, "HTTP_HOST": r.Host, // added here, since not always part of headers "REQUEST_URI": reqURL.RequestURI(), @@ -226,7 +237,7 @@ func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { // PATH_TRANSLATED should only exist if PATH_INFO is defined. // Info: https://www.ietf.org/rfc/rfc3875 Page 14 if env["PATH_INFO"] != "" { - env["PATH_TRANSLATED"] = filepath.Join(t.Root, pathInfo) // Info: http://www.oreilly.com/openbook/cgi/ch02_04.html + env["PATH_TRANSLATED"] = filepath.Join(root, pathInfo) // Info: http://www.oreilly.com/openbook/cgi/ch02_04.html } // Some web apps rely on knowing HTTPS or not @@ -248,7 +259,6 @@ func (t Transport) buildEnv(r *http.Request) (map[string]string, error) { } // Add env variables from config (with support for placeholders in values) - repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) for _, envVar := range t.EnvVars { env[envVar[0]] = repl.ReplaceAll(envVar[1], "") } @@ -283,3 +293,9 @@ var tlsProtocolStrings = map[uint16]string{ } var headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_") + +// Interface guards +var ( + _ caddy.Provisioner = (*Transport)(nil) + _ http.RoundTripper = (*Transport)(nil) +) diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index 6c1d9c8..c135ac8 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -200,3 +200,9 @@ type KeepAlive struct { MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"` IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle } + +// Interface guards +var ( + _ caddy.Provisioner = (*HTTPTransport)(nil) + _ http.RoundTripper = (*HTTPTransport)(nil) +) -- cgit v1.2.3 From 21d7b662e76feeb506cae9a616d92d85326566bd Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Fri, 6 Sep 2019 12:02:11 -0600 Subject: fastcgi: Use request context as base, not a new one --- modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go index 9d724c1..0368fde 100644 --- a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go +++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go @@ -92,7 +92,7 @@ func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) { } // TODO: doesn't dialer have a Timeout field? - ctx := context.Background() + ctx := r.Context() if t.DialTimeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, time.Duration(t.DialTimeout)) -- cgit v1.2.3 From 14f9662f9cc0f93e88d5efbbaf10de79070bea93 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Fri, 6 Sep 2019 12:36:45 -0600 Subject: Various fixes/tweaks to HTTP placeholder variables and file matching - Rename http.var.* -> http.vars.* to be more consistent - Prefixing a path matcher with * now invokes simple suffix matching - Handlers and matchers that need a root path default to {http.vars.root} - Clean replacer output on the file matcher's file selection suffix --- modules/caddyhttp/fileserver/caddyfile.go | 8 +------- modules/caddyhttp/fileserver/matcher.go | 16 +++++++++++----- modules/caddyhttp/fileserver/staticfiles.go | 4 ++++ modules/caddyhttp/matchers.go | 8 ++++---- modules/caddyhttp/replacer.go | 2 +- modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go | 2 +- modules/caddyhttp/templates/caddyfile.go | 5 ----- modules/caddyhttp/templates/templates.go | 3 +++ 8 files changed, 25 insertions(+), 23 deletions(-) (limited to 'modules') diff --git a/modules/caddyhttp/fileserver/caddyfile.go b/modules/caddyhttp/fileserver/caddyfile.go index 7afcc9e..4622af2 100644 --- a/modules/caddyhttp/fileserver/caddyfile.go +++ b/modules/caddyhttp/fileserver/caddyfile.go @@ -17,9 +17,9 @@ package fileserver import ( "encoding/json" - "github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite" "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" "github.com/caddyserver/caddy/v2/modules/caddyhttp" + "github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite" ) func init() { @@ -71,11 +71,6 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) } } - // if no root was configured explicitly, use site root - if fsrv.Root == "" { - fsrv.Root = "{http.var.root}" - } - // hide the Caddyfile (and any imported Caddyfiles) if configFiles := h.Caddyfiles(); len(configFiles) > 0 { for _, file := range configFiles { @@ -104,7 +99,6 @@ func parseTryFiles(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error) matcherSet := map[string]json.RawMessage{ "file": h.JSON(MatchFile{ - Root: "{http.var.root}", TryFiles: try, }, nil), } diff --git a/modules/caddyhttp/fileserver/matcher.go b/modules/caddyhttp/fileserver/matcher.go index b091250..99e217e 100644 --- a/modules/caddyhttp/fileserver/matcher.go +++ b/modules/caddyhttp/fileserver/matcher.go @@ -18,6 +18,7 @@ import ( "fmt" "net/http" "os" + "path" "time" "github.com/caddyserver/caddy/v2" @@ -87,8 +88,13 @@ func (m *MatchFile) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } } } + return nil +} + +// Provision sets up m's defaults. +func (m *MatchFile) Provision(_ caddy.Context) error { if m.Root == "" { - m.Root = "{http.var.root}" + m.Root = "{http.vars.root}" } return nil } @@ -141,7 +147,7 @@ func (m MatchFile) selectFile(r *http.Request) (rel, abs string, matched bool) { switch m.TryPolicy { case "", tryPolicyFirstExist: for _, f := range m.TryFiles { - suffix := repl.ReplaceAll(f, "") + suffix := path.Clean(repl.ReplaceAll(f, "")) fullpath := sanitizedPathJoin(root, suffix) if fileExists(fullpath) { return suffix, fullpath, true @@ -153,7 +159,7 @@ func (m MatchFile) selectFile(r *http.Request) (rel, abs string, matched bool) { var largestFilename string var largestSuffix string for _, f := range m.TryFiles { - suffix := repl.ReplaceAll(f, "") + suffix := path.Clean(repl.ReplaceAll(f, "")) fullpath := sanitizedPathJoin(root, suffix) info, err := os.Stat(fullpath) if err == nil && info.Size() > largestSize { @@ -169,7 +175,7 @@ func (m MatchFile) selectFile(r *http.Request) (rel, abs string, matched bool) { var smallestFilename string var smallestSuffix string for _, f := range m.TryFiles { - suffix := repl.ReplaceAll(f, "") + suffix := path.Clean(repl.ReplaceAll(f, "")) fullpath := sanitizedPathJoin(root, suffix) info, err := os.Stat(fullpath) if err == nil && (smallestSize == 0 || info.Size() < smallestSize) { @@ -185,7 +191,7 @@ func (m MatchFile) selectFile(r *http.Request) (rel, abs string, matched bool) { var recentFilename string var recentSuffix string for _, f := range m.TryFiles { - suffix := repl.ReplaceAll(f, "") + suffix := path.Clean(repl.ReplaceAll(f, "")) fullpath := sanitizedPathJoin(root, suffix) info, err := os.Stat(fullpath) if err == nil && diff --git a/modules/caddyhttp/fileserver/staticfiles.go b/modules/caddyhttp/fileserver/staticfiles.go index cdac453..cfb79f8 100644 --- a/modules/caddyhttp/fileserver/staticfiles.go +++ b/modules/caddyhttp/fileserver/staticfiles.go @@ -57,6 +57,10 @@ func (FileServer) CaddyModule() caddy.ModuleInfo { // Provision sets up the static files responder. func (fsrv *FileServer) Provision(ctx caddy.Context) error { + if fsrv.Root == "" { + fsrv.Root = "{http.vars.root}" + } + if fsrv.IndexNames == nil { fsrv.IndexNames = defaultIndexNames } diff --git a/modules/caddyhttp/matchers.go b/modules/caddyhttp/matchers.go index 2ddefd0..7fc8aea 100644 --- a/modules/caddyhttp/matchers.go +++ b/modules/caddyhttp/matchers.go @@ -22,7 +22,6 @@ import ( "net/http" "net/textproto" "net/url" - "path" "path/filepath" "regexp" "strings" @@ -151,12 +150,13 @@ func (MatchPath) CaddyModule() caddy.ModuleInfo { // Match returns true if r matches m. func (m MatchPath) Match(r *http.Request) bool { for _, matchPath := range m { - compare := r.URL.Path + // as a special case, if the first character is a + // wildcard, treat it as a quick suffix match if strings.HasPrefix(matchPath, "*") { - compare = path.Base(compare) + return strings.HasSuffix(r.URL.Path, matchPath[1:]) } // can ignore error here because we can't handle it anyway - matches, _ := filepath.Match(matchPath, compare) + matches, _ := filepath.Match(matchPath, r.URL.Path) if matches { return true } diff --git a/modules/caddyhttp/replacer.go b/modules/caddyhttp/replacer.go index cc29789..f7f69a4 100644 --- a/modules/caddyhttp/replacer.go +++ b/modules/caddyhttp/replacer.go @@ -173,6 +173,6 @@ const ( cookieReplPrefix = "http.request.cookie." hostLabelReplPrefix = "http.request.host.labels." pathPartsReplPrefix = "http.request.uri.path." - varsReplPrefix = "http.var." + varsReplPrefix = "http.vars." respHeaderReplPrefix = "http.response.header." ) diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go index 0368fde..66779e4 100644 --- a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go +++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go @@ -79,7 +79,7 @@ func (Transport) CaddyModule() caddy.ModuleInfo { // Provision sets up t. func (t *Transport) Provision(_ caddy.Context) error { if t.Root == "" { - t.Root = "{http.var.root}" + t.Root = "{http.vars.root}" } return nil } diff --git a/modules/caddyhttp/templates/caddyfile.go b/modules/caddyhttp/templates/caddyfile.go index d948da0..1336a60 100644 --- a/modules/caddyhttp/templates/caddyfile.go +++ b/modules/caddyhttp/templates/caddyfile.go @@ -53,10 +53,5 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) } } } - - if t.IncludeRoot == "" { - t.IncludeRoot = "{http.var.root}" - } - return t, nil } diff --git a/modules/caddyhttp/templates/templates.go b/modules/caddyhttp/templates/templates.go index 1cd347c..05a2f63 100644 --- a/modules/caddyhttp/templates/templates.go +++ b/modules/caddyhttp/templates/templates.go @@ -50,6 +50,9 @@ func (t *Templates) Provision(ctx caddy.Context) error { if t.MIMETypes == nil { t.MIMETypes = defaultMIMETypes } + if t.IncludeRoot == "" { + t.IncludeRoot = "{http.vars.root}" + } return nil } -- cgit v1.2.3 From 4bd949652564eff8ccfc50f105f0d35f0af26402 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Fri, 6 Sep 2019 12:57:12 -0600 Subject: Fix Schrodinger's file existence check in file matcher See: https://stackoverflow.com/a/12518877/1048862 For example, trying to check the existence of "/www/index.php/index.php" fails but not with an os.IsNotExist()-type error. So we have to assume that a file that cannot be successfully stat'ed at all does not exist. --- modules/caddyhttp/fileserver/matcher.go | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/caddyhttp/fileserver/matcher.go b/modules/caddyhttp/fileserver/matcher.go index 99e217e..88ce1d0 100644 --- a/modules/caddyhttp/fileserver/matcher.go +++ b/modules/caddyhttp/fileserver/matcher.go @@ -207,10 +207,22 @@ func (m MatchFile) selectFile(r *http.Request) (rel, abs string, matched bool) { return } -// fileExists returns true if file exists. +// fileExists returns true if file exists, +// false if it doesn't, or false if there +// was any other error. func fileExists(file string) bool { _, err := os.Stat(file) - return !os.IsNotExist(err) + if err == nil { + return true + } else if os.IsNotExist(err) { + return false + } else { + // we don't know if it exists, + // so assume it doesn't, since + // there must have been some + // other error anyway + return false + } } const ( -- cgit v1.2.3 From 97ace2a39e058f435d4e6adbee874eaabb42e45d Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Fri, 6 Sep 2019 13:32:02 -0600 Subject: File matcher enforces trailing-slash convention to match dirs/files --- modules/caddyhttp/fileserver/matcher.go | 42 +++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/caddyhttp/fileserver/matcher.go b/modules/caddyhttp/fileserver/matcher.go index 88ce1d0..fde086e 100644 --- a/modules/caddyhttp/fileserver/matcher.go +++ b/modules/caddyhttp/fileserver/matcher.go @@ -19,6 +19,7 @@ import ( "net/http" "os" "path" + "strings" "time" "github.com/caddyserver/caddy/v2" @@ -149,7 +150,7 @@ func (m MatchFile) selectFile(r *http.Request) (rel, abs string, matched bool) { for _, f := range m.TryFiles { suffix := path.Clean(repl.ReplaceAll(f, "")) fullpath := sanitizedPathJoin(root, suffix) - if fileExists(fullpath) { + if strictFileExists(fullpath) { return suffix, fullpath, true } } @@ -207,22 +208,33 @@ func (m MatchFile) selectFile(r *http.Request) (rel, abs string, matched bool) { return } -// fileExists returns true if file exists, -// false if it doesn't, or false if there -// was any other error. -func fileExists(file string) bool { - _, err := os.Stat(file) - if err == nil { - return true - } else if os.IsNotExist(err) { - return false - } else { - // we don't know if it exists, - // so assume it doesn't, since - // there must have been some - // other error anyway +// strictFileExists returns true if file exists +// and matches the convention of the given file +// path. If the path ends in a forward slash, +// the file must also be a directory; if it does +// NOT end in a forward slash, the file must NOT +// be a directory. +func strictFileExists(file string) bool { + stat, err := os.Stat(file) + if err != nil { + // in reality, this can be any error + // such as permission or even obscure + // ones like "is not a directory" (when + // trying to stat a file within a file); + // in those cases we can't be sure if + // the file exists, so we just treat any + // error as if it does not exist; see + // https://stackoverflow.com/a/12518877/1048862 return false } + if strings.HasSuffix(file, "/") { + // by convention, file paths ending + // in a slash must be a directory + return stat.IsDir() + } + // by convention, file paths NOT ending + // in a slash must NOT be a directory + return !stat.IsDir() } const ( -- cgit v1.2.3 From f6126acf379963136a6caeb818296a7510abd532 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Fri, 6 Sep 2019 14:25:16 -0600 Subject: Header matchers: allow matching presence of header with empty list --- modules/caddyhttp/matchers.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/caddyhttp/matchers.go b/modules/caddyhttp/matchers.go index 7fc8aea..4d0eea5 100644 --- a/modules/caddyhttp/matchers.go +++ b/modules/caddyhttp/matchers.go @@ -271,8 +271,13 @@ func (m *MatchHeader) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { // Match returns true if r matches m. func (m MatchHeader) Match(r *http.Request) bool { for field, allowedFieldVals := range m { + actualFieldVals, fieldExists := r.Header[textproto.CanonicalMIMEHeaderKey(field)] + if allowedFieldVals != nil && len(allowedFieldVals) == 0 && fieldExists { + // a non-nil but empty list of allowed values means + // match if the header field exists at all + continue + } var match bool - actualFieldVals := r.Header[textproto.CanonicalMIMEHeaderKey(field)] fieldVals: for _, actualFieldVal := range actualFieldVals { for _, allowedFieldVal := range allowedFieldVals { @@ -625,8 +630,13 @@ func (rm ResponseMatcher) matchStatusCode(statusCode int) bool { func (rm ResponseMatcher) matchHeaders(hdr http.Header) bool { for field, allowedFieldVals := range rm.Headers { + actualFieldVals, fieldExists := hdr[textproto.CanonicalMIMEHeaderKey(field)] + if allowedFieldVals != nil && len(allowedFieldVals) == 0 && fieldExists { + // a non-nil but empty list of allowed values means + // match if the header field exists at all + continue + } var match bool - actualFieldVals := hdr[textproto.CanonicalMIMEHeaderKey(field)] fieldVals: for _, actualFieldVal := range actualFieldVals { for _, allowedFieldVal := range allowedFieldVals { -- cgit v1.2.3 From 50e62d06bcbc6b6486b382a22c633772443cfb6d Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Mon, 9 Sep 2019 12:23:27 -0600 Subject: reverse_proxy: Caddyfile integration (and fix blocks in Dispenser) --- modules/caddyhttp/reverseproxy/caddyfile.go | 486 +++++++++++++++++++++ .../caddyhttp/reverseproxy/fastcgi/caddyfile.go | 54 +++ 2 files changed, 540 insertions(+) create mode 100644 modules/caddyhttp/reverseproxy/caddyfile.go create mode 100644 modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go (limited to 'modules') diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go new file mode 100644 index 0000000..ffa3ca0 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -0,0 +1,486 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reverseproxy + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig" + "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" + "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" + "github.com/dustin/go-humanize" +) + +func init() { + httpcaddyfile.RegisterHandlerDirective("reverse_proxy", parseCaddyfile) +} + +func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { + rp := new(Handler) + err := rp.UnmarshalCaddyfile(h.Dispenser) + if err != nil { + return nil, err + } + return rp, nil +} + +// UnmarshalCaddyfile sets up the handler from Caddyfile tokens. Syntax: +// +// reverse_proxy [] [] { +// # upstreams +// to +// +// # load balancing +// lb_policy [] +// lb_try_duration +// lb_try_interval +// +// # active health checking +// health_path +// health_port +// health_interval +// health_timeout +// health_status +// health_body +// +// # passive health checking +// max_fails +// fail_duration +// max_conns +// unhealthy_status +// unhealthy_latency +// +// # round trip +// transport { +// ... +// } +// } +// +func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + for d.Next() { + for _, up := range d.RemainingArgs() { + h.Upstreams = append(h.Upstreams, &Upstream{ + Dial: up, + }) + } + + for d.NextBlock() { + switch d.Val() { + case "to": + args := d.RemainingArgs() + if len(args) == 0 { + return d.ArgErr() + } + for _, up := range args { + h.Upstreams = append(h.Upstreams, &Upstream{ + Dial: up, + }) + } + + case "lb_policy": + if !d.NextArg() { + return d.ArgErr() + } + if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil { + return d.Err("load balancing selection policy already specified") + } + name := d.Val() + mod, err := caddy.GetModule("http.handlers.reverse_proxy.selection_policies." + name) + if err != nil { + return d.Errf("getting load balancing policy module '%s': %v", mod.Name, err) + } + unm, ok := mod.New().(caddyfile.Unmarshaler) + if !ok { + return d.Errf("load balancing policy module '%s' is not a Caddyfile unmarshaler", mod.Name) + } + err = unm.UnmarshalCaddyfile(d.NewFromNextTokens()) + if err != nil { + return err + } + sel, ok := unm.(Selector) + if !ok { + return d.Errf("module %s is not a Selector", mod.Name) + } + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + h.LoadBalancing.SelectionPolicyRaw = caddyconfig.JSONModuleObject(sel, "policy", name, nil) + + case "lb_try_duration": + if !d.NextArg() { + return d.ArgErr() + } + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value %s: %v", d.Val(), err) + } + h.LoadBalancing.TryDuration = caddy.Duration(dur) + + case "lb_try_interval": + if !d.NextArg() { + return d.ArgErr() + } + if h.LoadBalancing == nil { + h.LoadBalancing = new(LoadBalancing) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad interval value '%s': %v", d.Val(), err) + } + h.LoadBalancing.TryInterval = caddy.Duration(dur) + + case "health_path": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + h.HealthChecks.Active.Path = d.Val() + + case "health_port": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + portNum, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("bad port number '%s': %v", d.Val(), err) + } + h.HealthChecks.Active.Port = portNum + + case "health_interval": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad interval value %s: %v", d.Val(), err) + } + h.HealthChecks.Active.Interval = caddy.Duration(dur) + + case "health_timeout": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad timeout value %s: %v", d.Val(), err) + } + h.HealthChecks.Active.Timeout = caddy.Duration(dur) + + case "health_status": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + val := d.Val() + if len(val) == 3 && strings.HasSuffix(val, "xx") { + val = val[:1] + } + statusNum, err := strconv.Atoi(val[:1]) + if err != nil { + return d.Errf("bad status value '%s': %v", d.Val(), err) + } + h.HealthChecks.Active.ExpectStatus = statusNum + + case "health_body": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Active == nil { + h.HealthChecks.Active = new(ActiveHealthChecks) + } + h.HealthChecks.Active.ExpectBody = d.Val() + + case "max_fails": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + maxFails, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("invalid maximum fail count '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.MaxFails = maxFails + + case "fail_duration": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.FailDuration = caddy.Duration(dur) + + case "unhealthy_request_count": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + maxConns, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("invalid maximum connection count '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.UnhealthyRequestCount = maxConns + + case "unhealthy_status": + args := d.RemainingArgs() + if len(args) == 0 { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + for _, arg := range args { + if len(arg) == 3 && strings.HasSuffix(arg, "xx") { + arg = arg[:1] + } + statusNum, err := strconv.Atoi(arg[:1]) + if err != nil { + return d.Errf("bad status value '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.UnhealthyStatus = append(h.HealthChecks.Passive.UnhealthyStatus, statusNum) + } + + case "unhealthy_latency": + if !d.NextArg() { + return d.ArgErr() + } + if h.HealthChecks == nil { + h.HealthChecks = new(HealthChecks) + } + if h.HealthChecks.Passive == nil { + h.HealthChecks.Passive = new(PassiveHealthChecks) + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value '%s': %v", d.Val(), err) + } + h.HealthChecks.Passive.UnhealthyLatency = caddy.Duration(dur) + + case "transport": + if !d.NextArg() { + return d.ArgErr() + } + if h.TransportRaw != nil { + return d.Err("transport already specified") + } + name := d.Val() + mod, err := caddy.GetModule("http.handlers.reverse_proxy.transport." + name) + if err != nil { + return d.Errf("getting transport module '%s': %v", mod.Name, err) + } + unm, ok := mod.New().(caddyfile.Unmarshaler) + if !ok { + return d.Errf("transport module '%s' is not a Caddyfile unmarshaler", mod.Name) + } + d.Next() // consume the module name token + err = unm.UnmarshalCaddyfile(d.NewFromNextTokens()) + if err != nil { + return err + } + rt, ok := unm.(http.RoundTripper) + if !ok { + return d.Errf("module %s is not a RoundTripper", mod.Name) + } + h.TransportRaw = caddyconfig.JSONModuleObject(rt, "protocol", name, nil) + + default: + return d.Errf("unrecognized subdirective %s", d.Val()) + } + } + } + + return nil +} + +// UnmarshalCaddyfile deserializes Caddyfile tokens into h. +// +// transport http { +// read_buffer +// write_buffer +// dial_timeout +// tls_client_auth +// tls_insecure_skip_verify +// tls_timeout +// keepalive [off|] +// keepalive_idle_conns +// } +// +func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + for d.NextBlock() { + switch d.Val() { + case "read_buffer": + if !d.NextArg() { + return d.ArgErr() + } + size, err := humanize.ParseBytes(d.Val()) + if err != nil { + return d.Errf("invalid read buffer size '%s': %v", d.Val(), err) + } + h.ReadBufferSize = int(size) + + case "write_buffer": + if !d.NextArg() { + return d.ArgErr() + } + size, err := humanize.ParseBytes(d.Val()) + if err != nil { + return d.Errf("invalid write buffer size '%s': %v", d.Val(), err) + } + h.WriteBufferSize = int(size) + + case "dial_timeout": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad timeout value '%s': %v", d.Val(), err) + } + h.DialTimeout = caddy.Duration(dur) + + case "tls_client_auth": + args := d.RemainingArgs() + if len(args) != 2 { + return d.ArgErr() + } + if h.TLS == nil { + h.TLS = new(TLSConfig) + } + h.TLS.ClientCertificateFile = args[0] + h.TLS.ClientCertificateKeyFile = args[1] + + case "tls_insecure_skip_verify": + if d.NextArg() { + return d.ArgErr() + } + if h.TLS == nil { + h.TLS = new(TLSConfig) + } + h.TLS.InsecureSkipVerify = true + + case "tls_timeout": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad timeout value '%s': %v", d.Val(), err) + } + if h.TLS == nil { + h.TLS = new(TLSConfig) + } + h.TLS.HandshakeTimeout = caddy.Duration(dur) + + case "keepalive": + if !d.NextArg() { + return d.ArgErr() + } + if h.KeepAlive == nil { + h.KeepAlive = new(KeepAlive) + } + if d.Val() == "off" { + var disable bool + h.KeepAlive.Enabled = &disable + } + dur, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("bad duration value '%s': %v", d.Val(), err) + } + h.KeepAlive.IdleConnTimeout = caddy.Duration(dur) + + case "keepalive_idle_conns": + if !d.NextArg() { + return d.ArgErr() + } + num, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("bad integer value '%s': %v", d.Val(), err) + } + if h.KeepAlive == nil { + h.KeepAlive = new(KeepAlive) + } + h.KeepAlive.MaxIdleConns = num + + default: + return d.Errf("unrecognized subdirective %s", d.Val()) + } + } + return nil +} + +// Interface guards +var ( + _ caddyfile.Unmarshaler = (*Handler)(nil) + _ caddyfile.Unmarshaler = (*HTTPTransport)(nil) +) diff --git a/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go b/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go new file mode 100644 index 0000000..c8b9f63 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/fastcgi/caddyfile.go @@ -0,0 +1,54 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fastcgi + +import "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" + +// UnmarshalCaddyfile deserializes Caddyfile tokens into h. +// +// transport fastcgi { +// root +// split +// env +// } +// +func (t *Transport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + for d.NextBlock() { + switch d.Val() { + case "root": + if !d.NextArg() { + return d.ArgErr() + } + t.Root = d.Val() + + case "split": + if !d.NextArg() { + return d.ArgErr() + } + t.SplitPath = d.Val() + + case "env": + args := d.RemainingArgs() + if len(args) != 2 { + return d.ArgErr() + } + t.EnvVars = append(t.EnvVars, [2]string{args[0], args[1]}) + + default: + return d.Errf("unrecognized subdirective %s", d.Val()) + } + } + return nil +} -- cgit v1.2.3 From b4f4fcd437c2f9816f9511217bde703679808679 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Mon, 9 Sep 2019 21:44:58 -0600 Subject: Migrate some selection policy tests over to v2 --- modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go | 2 +- modules/caddyhttp/reverseproxy/hosts.go | 20 +- .../caddyhttp/reverseproxy/selectionpolicies.go | 6 +- .../reverseproxy/selectionpolicies_test.go | 604 +++++++++------------ 4 files changed, 272 insertions(+), 360 deletions(-) (limited to 'modules') diff --git a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go index 66779e4..91039c9 100644 --- a/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go +++ b/modules/caddyhttp/reverseproxy/fastcgi/fastcgi.go @@ -55,7 +55,7 @@ type Transport struct { // PATH_INFO for the CGI script to use. SplitPath string `json:"split_path,omitempty"` - // Environment Variables + // Environment variables (TODO: make a map of string to value...?) EnvVars [][2]string `json:"env,omitempty"` // The duration used to set a deadline when connecting to an upstream. diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index ad27625..1c0fae3 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -34,19 +34,21 @@ type Host interface { // Unhealthy returns true if the backend is unhealthy. Unhealthy() bool - // CountRequest counts the given number of requests - // as currently in process with the host. The count - // should not go below 0. + // CountRequest atomically counts the given number of + // requests as currently in process with the host. The + // count should not go below 0. CountRequest(int) error - // CountFail counts the given number of failures - // with the host. The count should not go below 0. + // CountFail atomically counts the given number of + // failures with the host. The count should not go + // below 0. CountFail(int) error - // SetHealthy marks the host as either healthy (true) - // or unhealthy (false). If the given status is the - // same, this should be a no-op. It returns true if - // the given status was different, false otherwise. + // SetHealthy atomically marks the host as either + // healthy (true) or unhealthy (false). If the given + // status is the same, this should be a no-op and + // return false. It returns true if the status was + // changed; i.e. if it is now different from before. SetHealthy(bool) (bool, error) } diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go index 9680583..5bb2d62 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -82,7 +82,7 @@ type RandomChoiceSelection struct { // CaddyModule returns the Caddy module information. func (RandomChoiceSelection) CaddyModule() caddy.ModuleInfo { return caddy.ModuleInfo{ - Name: "http.handlers.reverse_proxy.selection_policies.random_choice", + Name: "http.handlers.reverse_proxy.selection_policies.random_choose", New: func() caddy.Module { return new(RandomChoiceSelection) }, } } @@ -147,14 +147,14 @@ func (LeastConnSelection) CaddyModule() caddy.ModuleInfo { func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream { var bestHost *Upstream var count int - var leastReqs int + leastReqs := -1 for _, host := range pool { if !host.Available() { continue } numReqs := host.NumRequests() - if numReqs < leastReqs { + if leastReqs == -1 || numReqs < leastReqs { leastReqs = numReqs count = 0 } diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go index 8006fb1..e9939d6 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go @@ -14,350 +14,260 @@ package reverseproxy -// TODO: finish migrating these - -// import ( -// "net/http" -// "net/http/httptest" -// "os" -// "testing" -// ) - -// var workableServer *httptest.Server - -// func TestMain(m *testing.M) { -// workableServer = httptest.NewServer(http.HandlerFunc( -// func(w http.ResponseWriter, r *http.Request) { -// // do nothing -// })) -// r := m.Run() -// workableServer.Close() -// os.Exit(r) -// } - -// type customPolicy struct{} - -// func (customPolicy) Select(pool HostPool, _ *http.Request) Host { -// return pool[0] -// } - -// func testPool() HostPool { -// pool := []*UpstreamHost{ -// { -// Name: workableServer.URL, // this should resolve (healthcheck test) -// }, -// { -// Name: "http://localhost:99998", // this shouldn't -// }, -// { -// Name: "http://C", -// }, -// } -// return HostPool(pool) -// } - -// func TestRoundRobinPolicy(t *testing.T) { -// pool := testPool() -// rrPolicy := &RoundRobin{} -// request, _ := http.NewRequest("GET", "/", nil) - -// h := rrPolicy.Select(pool, request) -// // First selected host is 1, because counter starts at 0 -// // and increments before host is selected -// if h != pool[1] { -// t.Error("Expected first round robin host to be second host in the pool.") -// } -// h = rrPolicy.Select(pool, request) -// if h != pool[2] { -// t.Error("Expected second round robin host to be third host in the pool.") -// } -// h = rrPolicy.Select(pool, request) -// if h != pool[0] { -// t.Error("Expected third round robin host to be first host in the pool.") -// } -// // mark host as down -// pool[1].Unhealthy = 1 -// h = rrPolicy.Select(pool, request) -// if h != pool[2] { -// t.Error("Expected to skip down host.") -// } -// // mark host as up -// pool[1].Unhealthy = 0 - -// h = rrPolicy.Select(pool, request) -// if h == pool[2] { -// t.Error("Expected to balance evenly among healthy hosts") -// } -// // mark host as full -// pool[1].Conns = 1 -// pool[1].MaxConns = 1 -// h = rrPolicy.Select(pool, request) -// if h != pool[2] { -// t.Error("Expected to skip full host.") -// } -// } - -// func TestLeastConnPolicy(t *testing.T) { -// pool := testPool() -// lcPolicy := &LeastConn{} -// request, _ := http.NewRequest("GET", "/", nil) - -// pool[0].Conns = 10 -// pool[1].Conns = 10 -// h := lcPolicy.Select(pool, request) -// if h != pool[2] { -// t.Error("Expected least connection host to be third host.") -// } -// pool[2].Conns = 100 -// h = lcPolicy.Select(pool, request) -// if h != pool[0] && h != pool[1] { -// t.Error("Expected least connection host to be first or second host.") -// } -// } - -// func TestCustomPolicy(t *testing.T) { -// pool := testPool() -// customPolicy := &customPolicy{} -// request, _ := http.NewRequest("GET", "/", nil) - -// h := customPolicy.Select(pool, request) -// if h != pool[0] { -// t.Error("Expected custom policy host to be the first host.") -// } -// } - -// func TestIPHashPolicy(t *testing.T) { -// pool := testPool() -// ipHash := &IPHash{} -// request, _ := http.NewRequest("GET", "/", nil) -// // We should be able to predict where every request is routed. -// request.RemoteAddr = "172.0.0.1:80" -// h := ipHash.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected ip hash policy host to be the second host.") -// } -// request.RemoteAddr = "172.0.0.2:80" -// h = ipHash.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected ip hash policy host to be the second host.") -// } -// request.RemoteAddr = "172.0.0.3:80" -// h = ipHash.Select(pool, request) -// if h != pool[2] { -// t.Error("Expected ip hash policy host to be the third host.") -// } -// request.RemoteAddr = "172.0.0.4:80" -// h = ipHash.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected ip hash policy host to be the second host.") -// } - -// // we should get the same results without a port -// request.RemoteAddr = "172.0.0.1" -// h = ipHash.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected ip hash policy host to be the second host.") -// } -// request.RemoteAddr = "172.0.0.2" -// h = ipHash.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected ip hash policy host to be the second host.") -// } -// request.RemoteAddr = "172.0.0.3" -// h = ipHash.Select(pool, request) -// if h != pool[2] { -// t.Error("Expected ip hash policy host to be the third host.") -// } -// request.RemoteAddr = "172.0.0.4" -// h = ipHash.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected ip hash policy host to be the second host.") -// } - -// // we should get a healthy host if the original host is unhealthy and a -// // healthy host is available -// request.RemoteAddr = "172.0.0.1" -// pool[1].Unhealthy = 1 -// h = ipHash.Select(pool, request) -// if h != pool[2] { -// t.Error("Expected ip hash policy host to be the third host.") -// } - -// request.RemoteAddr = "172.0.0.2" -// h = ipHash.Select(pool, request) -// if h != pool[2] { -// t.Error("Expected ip hash policy host to be the third host.") -// } -// pool[1].Unhealthy = 0 - -// request.RemoteAddr = "172.0.0.3" -// pool[2].Unhealthy = 1 -// h = ipHash.Select(pool, request) -// if h != pool[0] { -// t.Error("Expected ip hash policy host to be the first host.") -// } -// request.RemoteAddr = "172.0.0.4" -// h = ipHash.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected ip hash policy host to be the second host.") -// } - -// // We should be able to resize the host pool and still be able to predict -// // where a request will be routed with the same IP's used above -// pool = []*UpstreamHost{ -// { -// Name: workableServer.URL, // this should resolve (healthcheck test) -// }, -// { -// Name: "http://localhost:99998", // this shouldn't -// }, -// } -// pool = HostPool(pool) -// request.RemoteAddr = "172.0.0.1:80" -// h = ipHash.Select(pool, request) -// if h != pool[0] { -// t.Error("Expected ip hash policy host to be the first host.") -// } -// request.RemoteAddr = "172.0.0.2:80" -// h = ipHash.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected ip hash policy host to be the second host.") -// } -// request.RemoteAddr = "172.0.0.3:80" -// h = ipHash.Select(pool, request) -// if h != pool[0] { -// t.Error("Expected ip hash policy host to be the first host.") -// } -// request.RemoteAddr = "172.0.0.4:80" -// h = ipHash.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected ip hash policy host to be the second host.") -// } - -// // We should get nil when there are no healthy hosts -// pool[0].Unhealthy = 1 -// pool[1].Unhealthy = 1 -// h = ipHash.Select(pool, request) -// if h != nil { -// t.Error("Expected ip hash policy host to be nil.") -// } -// } - -// func TestFirstPolicy(t *testing.T) { -// pool := testPool() -// firstPolicy := &First{} -// req := httptest.NewRequest(http.MethodGet, "/", nil) - -// h := firstPolicy.Select(pool, req) -// if h != pool[0] { -// t.Error("Expected first policy host to be the first host.") -// } - -// pool[0].Unhealthy = 1 -// h = firstPolicy.Select(pool, req) -// if h != pool[1] { -// t.Error("Expected first policy host to be the second host.") -// } -// } - -// func TestUriPolicy(t *testing.T) { -// pool := testPool() -// uriPolicy := &URIHash{} - -// request := httptest.NewRequest(http.MethodGet, "/test", nil) -// h := uriPolicy.Select(pool, request) -// if h != pool[0] { -// t.Error("Expected uri policy host to be the first host.") -// } - -// pool[0].Unhealthy = 1 -// h = uriPolicy.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected uri policy host to be the first host.") -// } - -// request = httptest.NewRequest(http.MethodGet, "/test_2", nil) -// h = uriPolicy.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected uri policy host to be the second host.") -// } - -// // We should be able to resize the host pool and still be able to predict -// // where a request will be routed with the same URI's used above -// pool = []*UpstreamHost{ -// { -// Name: workableServer.URL, // this should resolve (healthcheck test) -// }, -// { -// Name: "http://localhost:99998", // this shouldn't -// }, -// } - -// request = httptest.NewRequest(http.MethodGet, "/test", nil) -// h = uriPolicy.Select(pool, request) -// if h != pool[0] { -// t.Error("Expected uri policy host to be the first host.") -// } - -// pool[0].Unhealthy = 1 -// h = uriPolicy.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected uri policy host to be the first host.") -// } - -// request = httptest.NewRequest(http.MethodGet, "/test_2", nil) -// h = uriPolicy.Select(pool, request) -// if h != pool[1] { -// t.Error("Expected uri policy host to be the second host.") -// } - -// pool[0].Unhealthy = 1 -// pool[1].Unhealthy = 1 -// h = uriPolicy.Select(pool, request) -// if h != nil { -// t.Error("Expected uri policy policy host to be nil.") -// } -// } - -// func TestHeaderPolicy(t *testing.T) { -// pool := testPool() -// tests := []struct { -// Name string -// Policy *Header -// RequestHeaderName string -// RequestHeaderValue string -// NilHost bool -// HostIndex int -// }{ -// {"empty config", &Header{""}, "", "", true, 0}, -// {"empty config+header+value", &Header{""}, "Affinity", "somevalue", true, 0}, -// {"empty config+header", &Header{""}, "Affinity", "", true, 0}, - -// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 1}, -// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 2}, -// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 0}, - -// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue", false, 1}, -// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue2", false, 0}, -// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue3", false, 2}, -// {"hash route with empty value", &Header{"Affinity"}, "Affinity", "", false, 1}, -// } - -// for idx, test := range tests { -// request, _ := http.NewRequest("GET", "/", nil) -// if test.RequestHeaderName != "" { -// request.Header.Add(test.RequestHeaderName, test.RequestHeaderValue) -// } - -// host := test.Policy.Select(pool, request) -// if test.NilHost && host != nil { -// t.Errorf("%d: Expected host to be nil", idx) -// } -// if !test.NilHost && host == nil { -// t.Errorf("%d: Did not expect host to be nil", idx) -// } -// if !test.NilHost && host != pool[test.HostIndex] { -// t.Errorf("%d: Expected Header policy to be host %d", idx, test.HostIndex) -// } -// } -// } +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func testPool() UpstreamPool { + return UpstreamPool{ + {Host: new(upstreamHost)}, + {Host: new(upstreamHost)}, + {Host: new(upstreamHost)}, + } +} + +func TestRoundRobinPolicy(t *testing.T) { + pool := testPool() + rrPolicy := new(RoundRobinSelection) + req, _ := http.NewRequest("GET", "/", nil) + + h := rrPolicy.Select(pool, req) + // First selected host is 1, because counter starts at 0 + // and increments before host is selected + if h != pool[1] { + t.Error("Expected first round robin host to be second host in the pool.") + } + h = rrPolicy.Select(pool, req) + if h != pool[2] { + t.Error("Expected second round robin host to be third host in the pool.") + } + h = rrPolicy.Select(pool, req) + if h != pool[0] { + t.Error("Expected third round robin host to be first host in the pool.") + } + // mark host as down + pool[1].SetHealthy(false) + h = rrPolicy.Select(pool, req) + if h != pool[2] { + t.Error("Expected to skip down host.") + } + // mark host as up + pool[1].SetHealthy(true) + + h = rrPolicy.Select(pool, req) + if h == pool[2] { + t.Error("Expected to balance evenly among healthy hosts") + } + // mark host as full + pool[1].CountRequest(1) + pool[1].MaxRequests = 1 + h = rrPolicy.Select(pool, req) + if h != pool[2] { + t.Error("Expected to skip full host.") + } +} + +func TestLeastConnPolicy(t *testing.T) { + pool := testPool() + lcPolicy := new(LeastConnSelection) + req, _ := http.NewRequest("GET", "/", nil) + + pool[0].CountRequest(10) + pool[1].CountRequest(10) + h := lcPolicy.Select(pool, req) + if h != pool[2] { + t.Error("Expected least connection host to be third host.") + } + pool[2].CountRequest(100) + h = lcPolicy.Select(pool, req) + if h != pool[0] && h != pool[1] { + t.Error("Expected least connection host to be first or second host.") + } +} + +func TestIPHashPolicy(t *testing.T) { + pool := testPool() + ipHash := new(IPHashSelection) + req, _ := http.NewRequest("GET", "/", nil) + + // We should be able to predict where every request is routed. + req.RemoteAddr = "172.0.0.1:80" + h := ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + req.RemoteAddr = "172.0.0.2:80" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + req.RemoteAddr = "172.0.0.3:80" + h = ipHash.Select(pool, req) + if h != pool[2] { + t.Error("Expected ip hash policy host to be the third host.") + } + req.RemoteAddr = "172.0.0.4:80" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + + // we should get the same results without a port + req.RemoteAddr = "172.0.0.1" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + req.RemoteAddr = "172.0.0.2" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + req.RemoteAddr = "172.0.0.3" + h = ipHash.Select(pool, req) + if h != pool[2] { + t.Error("Expected ip hash policy host to be the third host.") + } + req.RemoteAddr = "172.0.0.4" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + + // we should get a healthy host if the original host is unhealthy and a + // healthy host is available + req.RemoteAddr = "172.0.0.1" + pool[1].SetHealthy(false) + h = ipHash.Select(pool, req) + if h != pool[2] { + t.Error("Expected ip hash policy host to be the third host.") + } + + req.RemoteAddr = "172.0.0.2" + h = ipHash.Select(pool, req) + if h != pool[2] { + t.Error("Expected ip hash policy host to be the third host.") + } + pool[1].SetHealthy(true) + + req.RemoteAddr = "172.0.0.3" + pool[2].SetHealthy(false) + h = ipHash.Select(pool, req) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + req.RemoteAddr = "172.0.0.4" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + + // We should be able to resize the host pool and still be able to predict + // where a req will be routed with the same IP's used above + pool = UpstreamPool{ + {Host: new(upstreamHost)}, + {Host: new(upstreamHost)}, + } + req.RemoteAddr = "172.0.0.1:80" + h = ipHash.Select(pool, req) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + req.RemoteAddr = "172.0.0.2:80" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + req.RemoteAddr = "172.0.0.3:80" + h = ipHash.Select(pool, req) + if h != pool[0] { + t.Error("Expected ip hash policy host to be the first host.") + } + req.RemoteAddr = "172.0.0.4:80" + h = ipHash.Select(pool, req) + if h != pool[1] { + t.Error("Expected ip hash policy host to be the second host.") + } + + // We should get nil when there are no healthy hosts + pool[0].SetHealthy(false) + pool[1].SetHealthy(false) + h = ipHash.Select(pool, req) + if h != nil { + t.Error("Expected ip hash policy host to be nil.") + } +} + +func TestFirstPolicy(t *testing.T) { + pool := testPool() + firstPolicy := new(FirstSelection) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + h := firstPolicy.Select(pool, req) + if h != pool[0] { + t.Error("Expected first policy host to be the first host.") + } + + pool[0].SetHealthy(false) + h = firstPolicy.Select(pool, req) + if h != pool[1] { + t.Error("Expected first policy host to be the second host.") + } +} + +func TestURIHashPolicy(t *testing.T) { + pool := testPool() + uriPolicy := new(URIHashSelection) + + request := httptest.NewRequest(http.MethodGet, "/test", nil) + h := uriPolicy.Select(pool, request) + if h != pool[0] { + t.Error("Expected uri policy host to be the first host.") + } + + pool[0].SetHealthy(false) + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the first host.") + } + + request = httptest.NewRequest(http.MethodGet, "/test_2", nil) + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the second host.") + } + + // We should be able to resize the host pool and still be able to predict + // where a request will be routed with the same URI's used above + pool = UpstreamPool{ + {Host: new(upstreamHost)}, + {Host: new(upstreamHost)}, + } + + request = httptest.NewRequest(http.MethodGet, "/test", nil) + h = uriPolicy.Select(pool, request) + if h != pool[0] { + t.Error("Expected uri policy host to be the first host.") + } + + pool[0].SetHealthy(false) + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the first host.") + } + + request = httptest.NewRequest(http.MethodGet, "/test_2", nil) + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the second host.") + } + + pool[0].SetHealthy(false) + pool[1].SetHealthy(false) + h = uriPolicy.Select(pool, request) + if h != nil { + t.Error("Expected uri policy policy host to be nil.") + } +} -- cgit v1.2.3