Diffstat (limited to 'modules/caddyhttp/reverseproxy/streaming.go')
-rw-r--r--  modules/caddyhttp/reverseproxy/streaming.go  205
1 file changed, 157 insertions(+), 48 deletions(-)
diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go
index 1db107a..155a1df 100644
--- a/modules/caddyhttp/reverseproxy/streaming.go
+++ b/modules/caddyhttp/reverseproxy/streaming.go
@@ -20,6 +20,8 @@ package reverseproxy
import (
"context"
+ "errors"
+ "fmt"
"io"
weakrand "math/rand"
"mime"
@@ -32,32 +34,46 @@ import (
"golang.org/x/net/http/httpguts"
)
-func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) {
+func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
// Taken from https://github.com/golang/go/commit/5c489514bc5e61ad9b5b07bd7d8ec65d66a0512a
// We know reqUpType is ASCII, it's checked by the caller.
if !asciiIsPrint(resUpType) {
- h.logger.Debug("backend tried to switch to invalid protocol",
+ logger.Debug("backend tried to switch to invalid protocol",
zap.String("backend_upgrade", resUpType))
return
}
if !asciiEqualFold(reqUpType, resUpType) {
- h.logger.Debug("backend tried to switch to unexpected protocol via Upgrade header",
+ logger.Debug("backend tried to switch to unexpected protocol via Upgrade header",
zap.String("backend_upgrade", resUpType),
zap.String("requested_upgrade", reqUpType))
return
}
- hj, ok := rw.(http.Hijacker)
+ backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
+ logger.Error("internal error: 101 switching protocols response with non-writable body")
+ return
+ }
+
+ // write header first, response headers should not be counted in size
+ // like the rest of handler chain.
+ copyHeader(rw.Header(), res.Header)
+ rw.WriteHeader(res.StatusCode)
+
+ logger.Debug("upgrading connection")
+
+ //nolint:bodyclose
+ conn, brw, hijackErr := http.NewResponseController(rw).Hijack()
+ if errors.Is(hijackErr, http.ErrNotSupported) {
h.logger.Sugar().Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)
return
}
- backConn, ok := res.Body.(io.ReadWriteCloser)
- if !ok {
- h.logger.Error("internal error: 101 switching protocols response with non-writable body")
+
+ if hijackErr != nil {
+ h.logger.Error("hijack failed on protocol switch", zap.Error(hijackErr))
return
}
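For reference, a minimal sketch of the hijack pattern this hunk switches to, assuming a plain net/http handler on Go 1.20+; switchProtocols and the error responses are illustrative only, not Caddy's API:

```go
package example

import (
	"errors"
	"log"
	"net/http"
)

// switchProtocols shows the flow above: take over the underlying connection
// via http.NewResponseController and treat http.ErrNotSupported as "this
// ResponseWriter cannot be hijacked".
func switchProtocols(w http.ResponseWriter, r *http.Request) {
	conn, brw, err := http.NewResponseController(w).Hijack()
	if errors.Is(err, http.ErrNotSupported) {
		// e.g. an HTTP/2 stream: there is no raw connection to hand over
		http.Error(w, "protocol switch not supported", http.StatusHTTPVersionNotSupported)
		return
	}
	if err != nil {
		log.Printf("hijack failed: %v", err)
		return
	}
	defer conn.Close()

	// anything buffered in the ResponseWriter must be flushed before the
	// two sides start talking over the raw connection
	if err := brw.Flush(); err != nil {
		return
	}
	// ... copy bytes between conn and the backend from here on ...
}
```

NewResponseController also unwraps wrapped ResponseWriters that expose Unwrap(), which is why the diff no longer needs the explicit http.Hijacker type assertion.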
@@ -74,18 +90,6 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
}()
defer close(backConnCloseCh)
- // write header first, response headers should not be counted in size
- // like the rest of handler chain.
- copyHeader(rw.Header(), res.Header)
- rw.WriteHeader(res.StatusCode)
-
- logger.Debug("upgrading connection")
- conn, brw, err := hj.Hijack()
- if err != nil {
- h.logger.Error("hijack failed on protocol switch", zap.Error(err))
- return
- }
-
start := time.Now()
defer func() {
conn.Close()
@@ -93,7 +97,7 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
}()
if err := brw.Flush(); err != nil {
- h.logger.Debug("response flush", zap.Error(err))
+ logger.Debug("response flush", zap.Error(err))
return
}
@@ -119,10 +123,23 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
spc := switchProtocolCopier{user: conn, backend: backConn}
+ // setup the timeout if requested
+ var timeoutc <-chan time.Time
+ if h.StreamTimeout > 0 {
+ timer := time.NewTimer(time.Duration(h.StreamTimeout))
+ defer timer.Stop()
+ timeoutc = timer.C
+ }
+
errc := make(chan error, 1)
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
- <-errc
+ select {
+ case err := <-errc:
+ logger.Debug("streaming error", zap.Error(err))
+ case time := <-timeoutc:
+ logger.Debug("stream timed out", zap.Time("timeout", time))
+ }
}
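The stream_timeout handling added above boils down to two copy goroutines racing an optional timer. A standalone sketch under that assumption, with pipeWithTimeout, a, and b as hypothetical names:

```go
package example

import (
	"io"
	"time"
)

// pipeWithTimeout mirrors the select above: a and b stand in for the hijacked
// downstream and upstream connections; a zero timeout means "no limit".
func pipeWithTimeout(a, b io.ReadWriter, timeout time.Duration) {
	errc := make(chan error, 1)
	go func() { _, err := io.Copy(a, b); errc <- err }()
	go func() { _, err := io.Copy(b, a); errc <- err }()

	// a nil channel never fires, so with no timeout configured the select
	// simply waits for the first copier to return
	var timeoutc <-chan time.Time
	if timeout > 0 {
		timer := time.NewTimer(timeout)
		defer timer.Stop()
		timeoutc = timer.C
	}

	select {
	case <-errc: // one direction finished or failed
	case <-timeoutc: // stream exceeded the configured limit
	}
	// the caller closes both connections, which unblocks the other copier
}
```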
// flushInterval returns the p.FlushInterval value, conditionally
@@ -167,38 +184,58 @@ func (h Handler) isBidirectionalStream(req *http.Request, res *http.Response) bo
(ae == "identity" || ae == "")
}
-func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
+func (h Handler) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration, logger *zap.Logger) error {
+ var w io.Writer = dst
+
if flushInterval != 0 {
- if wf, ok := dst.(writeFlusher); ok {
- mlw := &maxLatencyWriter{
- dst: wf,
- latency: flushInterval,
- }
- defer mlw.stop()
+ var mlwLogger *zap.Logger
+ if h.VerboseLogs {
+ mlwLogger = logger.Named("max_latency_writer")
+ } else {
+ mlwLogger = zap.NewNop()
+ }
+ mlw := &maxLatencyWriter{
+ dst: dst,
+ //nolint:bodyclose
+ flush: http.NewResponseController(dst).Flush,
+ latency: flushInterval,
+ logger: mlwLogger,
+ }
+ defer mlw.stop()
- // set up initial timer so headers get flushed even if body writes are delayed
- mlw.flushPending = true
- mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
+ // set up initial timer so headers get flushed even if body writes are delayed
+ mlw.flushPending = true
+ mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
- dst = mlw
- }
+ w = mlw
}
buf := streamingBufPool.Get().(*[]byte)
defer streamingBufPool.Put(buf)
- _, err := h.copyBuffer(dst, src, *buf)
+
+ var copyLogger *zap.Logger
+ if h.VerboseLogs {
+ copyLogger = logger
+ } else {
+ copyLogger = zap.NewNop()
+ }
+
+ _, err := h.copyBuffer(w, src, *buf, copyLogger)
return err
}
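As a rough illustration of what flushInterval buys, here is a simplified, self-contained variant that flushes on a per-read elapsed-time check; the real code uses the timer-driven maxLatencyWriter instead, and copyAndFlush is a hypothetical helper:

```go
package example

import (
	"io"
	"net/http"
	"time"
)

// copyAndFlush copies src to rw and pushes buffered data to the client at
// least every interval. It is a reduced stand-in for the flushInterval path.
func copyAndFlush(rw http.ResponseWriter, src io.Reader, interval time.Duration) error {
	rc := http.NewResponseController(rw)
	buf := make([]byte, 32*1024)
	lastFlush := time.Now()

	for {
		n, rerr := src.Read(buf)
		if n > 0 {
			if _, werr := rw.Write(buf[:n]); werr != nil {
				return werr
			}
			if time.Since(lastFlush) >= interval {
				// Flush reports http.ErrNotSupported when rw cannot flush
				_ = rc.Flush()
				lastFlush = time.Now()
			}
		}
		if rerr == io.EOF {
			return nil
		}
		if rerr != nil {
			return rerr
		}
	}
}
```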
// copyBuffer returns any write errors or non-EOF read errors, and the amount
// of bytes written.
-func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
+func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *zap.Logger) (int64, error) {
if len(buf) == 0 {
buf = make([]byte, defaultBufferSize)
}
var written int64
for {
+ logger.Debug("waiting to read from upstream")
nr, rerr := src.Read(buf)
+ logger := logger.With(zap.Int("read", nr))
+ logger.Debug("read from upstream", zap.Error(rerr))
if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
// TODO: this could be useful to know (indeed, it revealed an error in our
// fastcgi PoC earlier; but it's this single error report here that necessitates
@@ -210,12 +247,17 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, er
h.logger.Error("reading from backend", zap.Error(rerr))
}
if nr > 0 {
+ logger.Debug("writing to downstream")
nw, werr := dst.Write(buf[:nr])
if nw > 0 {
written += int64(nw)
}
+ logger.Debug("wrote to downstream",
+ zap.Int("written", nw),
+ zap.Int64("written_total", written),
+ zap.Error(werr))
if werr != nil {
- return written, werr
+ return written, fmt.Errorf("writing: %w", werr)
}
if nr != nw {
return written, io.ErrShortWrite
@@ -223,9 +265,9 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, er
}
if rerr != nil {
if rerr == io.EOF {
- rerr = nil
+ return written, nil
}
- return written, rerr
+ return written, fmt.Errorf("reading: %w", rerr)
}
}
}
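Because the copy errors are now wrapped with %w, callers can still match the underlying cause while the "reading:"/"writing:" prefix says which side of the copy failed. A small runnable check:

```go
package main

import (
	"errors"
	"fmt"
	"io"
)

func main() {
	// same wrapping style as copyBuffer above
	err := fmt.Errorf("reading: %w", io.ErrUnexpectedEOF)

	fmt.Println(err)                                 // reading: unexpected EOF
	fmt.Println(errors.Is(err, io.ErrUnexpectedEOF)) // true
}
```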
@@ -242,10 +284,70 @@ func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func
return func() {
h.connectionsMu.Lock()
delete(h.connections, conn)
+ // if there is no connection left before the connections close timer fires
+ if len(h.connections) == 0 && h.connectionsCloseTimer != nil {
+ // we release the timer that holds the reference to Handler
+ if (*h.connectionsCloseTimer).Stop() {
+ h.logger.Debug("stopped streaming connections close timer - all connections are already closed")
+ }
+ h.connectionsCloseTimer = nil
+ }
h.connectionsMu.Unlock()
}
}
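The registration logic follows a common pattern: a mutex-guarded map of open connections plus a returned deregister func. A trimmed-down sketch with hypothetical names (tracker, register); the timer-stop detail from the diff is reduced to a comment:

```go
package example

import (
	"io"
	"sync"
)

type tracker struct {
	mu    sync.Mutex
	conns map[io.Closer]struct{}
}

// register records conn and returns a function that forgets it again.
func (t *tracker) register(conn io.Closer) (deregister func()) {
	t.mu.Lock()
	if t.conns == nil {
		t.conns = make(map[io.Closer]struct{})
	}
	t.conns[conn] = struct{}{}
	t.mu.Unlock()

	return func() {
		t.mu.Lock()
		delete(t.conns, conn)
		// the real handler also stops its delayed-close timer here once
		// the last tracked connection is gone
		t.mu.Unlock()
	}
}
```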
+// closeConnections immediately closes all hijacked connections (both to client and backend).
+func (h *Handler) closeConnections() error {
+ var err error
+ h.connectionsMu.Lock()
+ defer h.connectionsMu.Unlock()
+
+ for _, oc := range h.connections {
+ if oc.gracefulClose != nil {
+ // this is potentially blocking while we have the lock on the connections
+ // map, but that should be OK since the server has in theory shut down
+ // and we are no longer using the connections map
+ gracefulErr := oc.gracefulClose()
+ if gracefulErr != nil && err == nil {
+ err = gracefulErr
+ }
+ }
+ closeErr := oc.conn.Close()
+ if closeErr != nil && err == nil {
+ err = closeErr
+ }
+ }
+ return err
+}
+
+// cleanupConnections closes hijacked connections.
+// Depending on the value of StreamCloseDelay it does that either immediately
+// or sets up a timer that will do that later.
+func (h *Handler) cleanupConnections() error {
+ if h.StreamCloseDelay == 0 {
+ return h.closeConnections()
+ }
+
+ h.connectionsMu.Lock()
+ defer h.connectionsMu.Unlock()
+ // the handler is shut down, no new connection can appear,
+ // so we can skip setting up the timer when there are no connections
+ if len(h.connections) > 0 {
+ delay := time.Duration(h.StreamCloseDelay)
+ h.connectionsCloseTimer = time.AfterFunc(delay, func() {
+ h.logger.Debug("closing streaming connections after delay",
+ zap.Duration("delay", delay))
+ err := h.closeConnections()
+ if err != nil {
+ h.logger.Error("failed to close connections after delay",
+ zap.Error(err),
+ zap.Duration("delay", delay))
+ }
+ })
+ }
+ return nil
+}
+
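The delayed-close behavior added here can be summarized as: a zero delay closes immediately, otherwise a time.AfterFunc timer does it later and can be stopped early. A hedged sketch with hypothetical names (cleanup, closeAll):

```go
package example

import "time"

// cleanup returns the armed timer (nil when it closed immediately) so the
// caller can Stop it early if every connection is gone before the delay.
func cleanup(delay time.Duration, closeAll func() error) (*time.Timer, error) {
	if delay == 0 {
		return nil, closeAll()
	}
	t := time.AfterFunc(delay, func() {
		// errors are only logged in the real code; nothing can retry here
		_ = closeAll()
	})
	return t, nil
}
```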
// writeCloseControl sends a best-effort Close control message to the given
// WebSocket connection. Thanks to @pascaldekloe who provided inspiration
// from his simple implementation of this I was able to learn from at:
@@ -365,29 +467,30 @@ type openConnection struct {
gracefulClose func() error
}
-type writeFlusher interface {
- io.Writer
- http.Flusher
-}
-
type maxLatencyWriter struct {
- dst writeFlusher
+ dst io.Writer
+ flush func() error
latency time.Duration // non-zero; negative means to flush immediately
mu sync.Mutex // protects t, flushPending, and dst.Flush
t *time.Timer
flushPending bool
+ logger *zap.Logger
}
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
n, err = m.dst.Write(p)
+ m.logger.Debug("wrote bytes", zap.Int("n", n), zap.Error(err))
if m.latency < 0 {
- m.dst.Flush()
+ m.logger.Debug("flushing immediately")
+ //nolint:errcheck
+ m.flush()
return
}
if m.flushPending {
+ m.logger.Debug("delayed flush already pending")
return
}
if m.t == nil {
@@ -395,6 +498,7 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
} else {
m.t.Reset(m.latency)
}
+ m.logger.Debug("timer set for delayed flush", zap.Duration("duration", m.latency))
m.flushPending = true
return
}
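Swapping the old writeFlusher interface for a plain flush func() error is what lets maxLatencyWriter take its flush from http.NewResponseController, which reports http.ErrNotSupported rather than silently skipping the flush. A minimal sketch of that decoupling, with flushingWriter as a hypothetical type:

```go
package example

import (
	"io"
	"net/http"
)

type flushingWriter struct {
	dst   io.Writer
	flush func() error
}

func newFlushingWriter(rw http.ResponseWriter) *flushingWriter {
	return &flushingWriter{
		dst: rw,
		// any func() error works here; ResponseController.Flush surfaces
		// http.ErrNotSupported when the writer cannot flush
		flush: http.NewResponseController(rw).Flush,
	}
}

// Write forwards to dst and flushes after every write (the "negative
// latency" case above, without the timer machinery).
func (w *flushingWriter) Write(p []byte) (int, error) {
	n, err := w.dst.Write(p)
	if err == nil {
		err = w.flush()
	}
	return n, err
}
```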
@@ -403,9 +507,12 @@ func (m *maxLatencyWriter) delayedFlush() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
+ m.logger.Debug("delayed flush is not pending")
return
}
- m.dst.Flush()
+ m.logger.Debug("delayed flush")
+ //nolint:errcheck
+ m.flush()
m.flushPending = false
}
@@ -445,5 +552,7 @@ var streamingBufPool = sync.Pool{
},
}
-const defaultBufferSize = 32 * 1024
-const wordSize = int(unsafe.Sizeof(uintptr(0)))
+const (
+ defaultBufferSize = 32 * 1024
+ wordSize = int(unsafe.Sizeof(uintptr(0)))
+)
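The streaming buffer pool stores *[]byte values so that putting a buffer back does not allocate a fresh interface value for the slice header on every Put. A small sketch of the same pattern, with bufPool and withBuffer as illustrative names:

```go
package example

import "sync"

const defaultBufferSize = 32 * 1024

var bufPool = sync.Pool{
	New: func() any {
		b := make([]byte, defaultBufferSize)
		return &b
	},
}

// withBuffer borrows a fixed-size buffer for the duration of f and returns
// it to the pool afterwards, as copyResponse does around copyBuffer.
func withBuffer(f func(buf []byte)) {
	bufp := bufPool.Get().(*[]byte)
	defer bufPool.Put(bufp)
	f(*bufp)
}
```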