diff options
Diffstat (limited to 'modules/caddyhttp/reverseproxy')
-rw-r--r-- | modules/caddyhttp/reverseproxy/reverseproxy.go | 5 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/streaming.go | 60 | ||||
-rw-r--r-- | modules/caddyhttp/reverseproxy/streaming_test.go | 6 |
3 files changed, 38 insertions, 33 deletions
diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index b331c6b..842d75d 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -962,9 +962,8 @@ func (h *Handler) finalizeResponse( // Force chunking if we saw a response trailer. // This prevents net/http from calculating the length for short // bodies and adding a Content-Length. - if fl, ok := rw.(http.Flusher); ok { - fl.Flush() - } + //nolint:bodyclose + http.NewResponseController(rw).Flush() } // total duration spent proxying, including writing response body diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 6c1e44c..3f2489d 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -20,6 +20,7 @@ package reverseproxy import ( "context" + "errors" "fmt" "io" weakrand "math/rand" @@ -51,17 +52,19 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit return } - hj, ok := rw.(http.Hijacker) - if !ok { - logger.Error("can't switch protocols using non-Hijacker ResponseWriter", zap.String("type", fmt.Sprintf("%T", rw))) - return - } backConn, ok := res.Body.(io.ReadWriteCloser) if !ok { logger.Error("internal error: 101 switching protocols response with non-writable body") return } + //nolint:bodyclose + conn, brw, hijackErr := http.NewResponseController(rw).Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { + h.logger.Sugar().Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw) + return + } + // adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5 backConnCloseCh := make(chan struct{}) go func() { @@ -81,9 +84,8 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit rw.WriteHeader(res.StatusCode) logger.Debug("upgrading connection") - conn, brw, err := hj.Hijack() - if err != nil { - logger.Error("hijack failed on protocol switch", zap.Error(err)) + if hijackErr != nil { + h.logger.Error("hijack failed on protocol switch", zap.Error(hijackErr)) return } @@ -181,26 +183,28 @@ func (h Handler) isBidirectionalStream(req *http.Request, res *http.Response) bo (ae == "identity" || ae == "") } -func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { +func (h Handler) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error { + var w io.Writer = dst + if flushInterval != 0 { - if wf, ok := dst.(writeFlusher); ok { - mlw := &maxLatencyWriter{ - dst: wf, - latency: flushInterval, - } - defer mlw.stop() + mlw := &maxLatencyWriter{ + dst: dst, + //nolint:bodyclose + flush: http.NewResponseController(dst).Flush, + latency: flushInterval, + } + defer mlw.stop() - // set up initial timer so headers get flushed even if body writes are delayed - mlw.flushPending = true - mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + // set up initial timer so headers get flushed even if body writes are delayed + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) - dst = mlw - } + w = mlw } buf := streamingBufPool.Get().(*[]byte) defer streamingBufPool.Put(buf) - _, err := h.copyBuffer(dst, src, *buf) + _, err := h.copyBuffer(w, src, *buf) return err } @@ -439,13 +443,9 @@ type openConnection struct { gracefulClose func() error } -type writeFlusher interface { - io.Writer - http.Flusher -} - type maxLatencyWriter struct { - dst writeFlusher + dst io.Writer + flush func() error latency time.Duration // non-zero; negative means to flush immediately mu sync.Mutex // protects t, flushPending, and dst.Flush @@ -458,7 +458,8 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { defer m.mu.Unlock() n, err = m.dst.Write(p) if m.latency < 0 { - m.dst.Flush() + //nolint:errcheck + m.flush() return } if m.flushPending { @@ -479,7 +480,8 @@ func (m *maxLatencyWriter) delayedFlush() { if !m.flushPending { // if stop was called but AfterFunc already started this goroutine return } - m.dst.Flush() + //nolint:errcheck + m.flush() m.flushPending = false } diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go index 4ed1f1e..919538f 100644 --- a/modules/caddyhttp/reverseproxy/streaming_test.go +++ b/modules/caddyhttp/reverseproxy/streaming_test.go @@ -2,6 +2,7 @@ package reverseproxy import ( "bytes" + "net/http/httptest" "strings" "testing" ) @@ -13,12 +14,15 @@ func TestHandlerCopyResponse(t *testing.T) { strings.Repeat("a", defaultBufferSize), strings.Repeat("123456789 123456789 123456789 12", 3000), } + dst := bytes.NewBuffer(nil) + recorder := httptest.NewRecorder() + recorder.Body = dst for _, d := range testdata { src := bytes.NewBuffer([]byte(d)) dst.Reset() - err := h.copyResponse(dst, src, 0) + err := h.copyResponse(recorder, src, 0) if err != nil { t.Errorf("failed with error: %v", err) } |