diff options
Diffstat (limited to 'modules/caddyhttp/reverseproxy')
-rw-r--r-- | modules/caddyhttp/reverseproxy/reverseproxy.go | 30 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/streaming.go | 78 |
2 files changed, 103 insertions, 5 deletions
diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 895682f..dc1458d 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -192,6 +192,10 @@ type Handler struct { // Holds the handle_response Caddyfile tokens while adapting handleResponseSegments []*caddyfile.Dispenser + // Stores upgraded requests (hijacked connections) for proper cleanup + connections map[io.ReadWriteCloser]openConnection + connectionsMu *sync.Mutex + ctx caddy.Context logger *zap.Logger events *caddyevents.App @@ -214,6 +218,8 @@ func (h *Handler) Provision(ctx caddy.Context) error { h.events = eventAppIface.(*caddyevents.App) h.ctx = ctx h.logger = ctx.Logger(h) + h.connections = make(map[io.ReadWriteCloser]openConnection) + h.connectionsMu = new(sync.Mutex) // verify SRV compatibility - TODO: LookupSRV deprecated; will be removed for i, v := range h.Upstreams { @@ -407,16 +413,34 @@ func (h *Handler) Provision(ctx caddy.Context) error { return nil } -// Cleanup cleans up the resources made by h during provisioning. +// Cleanup cleans up the resources made by h. func (h *Handler) Cleanup() error { - // TODO: Close keepalive connections on reload? https://github.com/caddyserver/caddy/pull/2507/files#diff-70219fd88fe3f36834f474ce6537ed26R762 + // close hijacked connections (both to client and backend) + var err error + h.connectionsMu.Lock() + for _, oc := range h.connections { + if oc.gracefulClose != nil { + // this is potentially blocking while we have the lock on the connections + // map, but that should be OK since the server has in theory shut down + // and we are no longer using the connections map + gracefulErr := oc.gracefulClose() + if gracefulErr != nil && err == nil { + err = gracefulErr + } + } + closeErr := oc.conn.Close() + if closeErr != nil && err == nil { + err = closeErr + } + } + h.connectionsMu.Unlock() // remove hosts from our config from the pool for _, upstream := range h.Upstreams { _, _ = hosts.Delete(upstream.String()) } - return nil + return err } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 298ab58..01d865d 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -20,6 +20,7 @@ package reverseproxy import ( "context" + "encoding/binary" "io" "mime" "net/http" @@ -27,6 +28,7 @@ import ( "time" "go.uber.org/zap" + "golang.org/x/net/http/httpguts" ) func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) { @@ -97,8 +99,26 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite return } - errc := make(chan error, 1) + // Ensure the hijacked client connection, and the new connection established + // with the backend, are both closed in the event of a server shutdown. This + // is done by registering them. We also try to gracefully close connections + // we recognize as websockets. + gracefulClose := func(conn io.ReadWriteCloser) func() error { + if isWebsocket(req) { + return func() error { + return writeCloseControl(conn) + } + } + return nil + } + deleteFrontConn := h.registerConnection(conn, gracefulClose(conn)) + deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn)) + defer deleteFrontConn() + defer deleteBackConn() + spc := switchProtocolCopier{user: conn, backend: backConn} + + errc := make(chan error, 1) go spc.copyToBackend(errc) go spc.copyFromBackend(errc) <-errc @@ -209,6 +229,60 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, er } } +// registerConnection holds onto conn so it can be closed in the event +// of a server shutdown. This is useful because hijacked connections or +// connections dialed to backends don't close when server is shut down. +// The caller should call the returned delete() function when the +// connection is done to remove it from memory. +func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error) (del func()) { + h.connectionsMu.Lock() + h.connections[conn] = openConnection{conn, gracefulClose} + h.connectionsMu.Unlock() + return func() { + h.connectionsMu.Lock() + delete(h.connections, conn) + h.connectionsMu.Unlock() + } +} + +// writeCloseControl sends a best-effort Close control message to the given +// WebSocket connection. Thanks to @pascaldekloe who provided inspiration +// from his simple implementation of this I was able to learn from at: +// github.com/pascaldekloe/websocket. +func writeCloseControl(conn io.Writer) error { + // https://github.com/pascaldekloe/websocket/blob/32050af67a5d/websocket.go#L119 + + var reason string // max 123 bytes (control frame payload limit is 125; status code takes 2) + const goingAway uint16 = 1001 + + // TODO: we might need to ensure we are the exclusive writer by this point (io.Copy is stopped)? + var writeBuf [127]byte + const closeMessage = 8 + const finalBit = 1 << 7 + writeBuf[0] = closeMessage | finalBit + writeBuf[1] = byte(len(reason) + 2) + binary.BigEndian.PutUint16(writeBuf[2:4], goingAway) + copy(writeBuf[4:], reason) + + // simply best-effort, but return error for logging purposes + _, err := conn.Write(writeBuf[:4+len(reason)]) + return err +} + +// isWebsocket returns true if r looks to be an upgrade request for WebSockets. +// It is a fairly naive check. +func isWebsocket(r *http.Request) bool { + return httpguts.HeaderValuesContainsToken(r.Header["Connection"], "upgrade") && + httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket") +} + +// openConnection maps an open connection to +// an optional function for graceful close. +type openConnection struct { + conn io.ReadWriteCloser + gracefulClose func() error +} + type writeFlusher interface { io.Writer http.Flusher @@ -265,7 +339,7 @@ func (m *maxLatencyWriter) stop() { // switchProtocolCopier exists so goroutines proxying data back and // forth have nice names in stacks. type switchProtocolCopier struct { - user, backend io.ReadWriter + user, backend io.ReadWriteCloser } func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { |