summaryrefslogtreecommitdiff
path: root/modules/caddyhttp/reverseproxy/streaming.go
diff options
context:
space:
mode:
Diffstat (limited to 'modules/caddyhttp/reverseproxy/streaming.go')
-rw-r--r--modules/caddyhttp/reverseproxy/streaming.go89
1 files changed, 81 insertions, 8 deletions
diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go
index 1f5387e..6c1e44c 100644
--- a/modules/caddyhttp/reverseproxy/streaming.go
+++ b/modules/caddyhttp/reverseproxy/streaming.go
@@ -33,19 +33,19 @@ import (
"golang.org/x/net/http/httpguts"
)
-func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) {
+func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
// Taken from https://github.com/golang/go/commit/5c489514bc5e61ad9b5b07bd7d8ec65d66a0512a
// We know reqUpType is ASCII, it's checked by the caller.
if !asciiIsPrint(resUpType) {
- h.logger.Debug("backend tried to switch to invalid protocol",
+ logger.Debug("backend tried to switch to invalid protocol",
zap.String("backend_upgrade", resUpType))
return
}
if !asciiEqualFold(reqUpType, resUpType) {
- h.logger.Debug("backend tried to switch to unexpected protocol via Upgrade header",
+ logger.Debug("backend tried to switch to unexpected protocol via Upgrade header",
zap.String("backend_upgrade", resUpType),
zap.String("requested_upgrade", reqUpType))
return
@@ -53,12 +53,12 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
hj, ok := rw.(http.Hijacker)
if !ok {
- h.logger.Sugar().Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)
+ logger.Error("can't switch protocols using non-Hijacker ResponseWriter", zap.String("type", fmt.Sprintf("%T", rw)))
return
}
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
- h.logger.Error("internal error: 101 switching protocols response with non-writable body")
+ logger.Error("internal error: 101 switching protocols response with non-writable body")
return
}
@@ -83,7 +83,7 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
logger.Debug("upgrading connection")
conn, brw, err := hj.Hijack()
if err != nil {
- h.logger.Error("hijack failed on protocol switch", zap.Error(err))
+ logger.Error("hijack failed on protocol switch", zap.Error(err))
return
}
@@ -94,7 +94,7 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
}()
if err := brw.Flush(); err != nil {
- h.logger.Debug("response flush", zap.Error(err))
+ logger.Debug("response flush", zap.Error(err))
return
}
@@ -120,10 +120,23 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
spc := switchProtocolCopier{user: conn, backend: backConn}
+ // setup the timeout if requested
+ var timeoutc <-chan time.Time
+ if h.StreamTimeout > 0 {
+ timer := time.NewTimer(time.Duration(h.StreamTimeout))
+ defer timer.Stop()
+ timeoutc = timer.C
+ }
+
errc := make(chan error, 1)
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
- <-errc
+ select {
+ case err := <-errc:
+ logger.Debug("streaming error", zap.Error(err))
+ case time := <-timeoutc:
+ logger.Debug("stream timed out", zap.Time("timeout", time))
+ }
}
// flushInterval returns the p.FlushInterval value, conditionally
@@ -243,10 +256,70 @@ func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func
return func() {
h.connectionsMu.Lock()
delete(h.connections, conn)
+ // if there is no connection left before the connections close timer fires
+ if len(h.connections) == 0 && h.connectionsCloseTimer != nil {
+ // we release the timer that holds the reference to Handler
+ if (*h.connectionsCloseTimer).Stop() {
+ h.logger.Debug("stopped streaming connections close timer - all connections are already closed")
+ }
+ h.connectionsCloseTimer = nil
+ }
h.connectionsMu.Unlock()
}
}
+// closeConnections immediately closes all hijacked connections (both to client and backend).
+func (h *Handler) closeConnections() error {
+ var err error
+ h.connectionsMu.Lock()
+ defer h.connectionsMu.Unlock()
+
+ for _, oc := range h.connections {
+ if oc.gracefulClose != nil {
+ // this is potentially blocking while we have the lock on the connections
+ // map, but that should be OK since the server has in theory shut down
+ // and we are no longer using the connections map
+ gracefulErr := oc.gracefulClose()
+ if gracefulErr != nil && err == nil {
+ err = gracefulErr
+ }
+ }
+ closeErr := oc.conn.Close()
+ if closeErr != nil && err == nil {
+ err = closeErr
+ }
+ }
+ return err
+}
+
+// cleanupConnections closes hijacked connections.
+// Depending on the value of StreamCloseDelay it does that either immediately
+// or sets up a timer that will do that later.
+func (h *Handler) cleanupConnections() error {
+ if h.StreamCloseDelay == 0 {
+ return h.closeConnections()
+ }
+
+ h.connectionsMu.Lock()
+ defer h.connectionsMu.Unlock()
+ // the handler is shut down, no new connection can appear,
+ // so we can skip setting up the timer when there are no connections
+ if len(h.connections) > 0 {
+ delay := time.Duration(h.StreamCloseDelay)
+ h.connectionsCloseTimer = time.AfterFunc(delay, func() {
+ h.logger.Debug("closing streaming connections after delay",
+ zap.Duration("delay", delay))
+ err := h.closeConnections()
+ if err != nil {
+ h.logger.Error("failed to closed connections after delay",
+ zap.Error(err),
+ zap.Duration("delay", delay))
+ }
+ })
+ }
+ return nil
+}
+
// writeCloseControl sends a best-effort Close control message to the given
// WebSocket connection. Thanks to @pascaldekloe who provided inspiration
// from his simple implementation of this I was able to learn from at: