summaryrefslogtreecommitdiff
path: root/modules/caddyhttp/reverseproxy
diff options
context:
space:
mode:
authorFrancis Lavoie <lavofr@gmail.com>2023-08-02 16:03:26 -0400
committerGitHub <noreply@github.com>2023-08-02 20:03:26 +0000
commitcd486c25d168caf58f4b6fe5d3252df9432901ec (patch)
tree1c444017467ff3339e5321429eb2c82d37fbf414 /modules/caddyhttp/reverseproxy
parente198c605bd68f4b3630e5fa1ae9f7ca5cac1a7d9 (diff)
caddyhttp: Make use of `http.ResponseController` (#5654)
* caddyhttp: Make use of http.ResponseController Also syncs the reverseproxy implementation with stdlib's which now uses ResponseController as well https://github.com/golang/go/commit/2449bbb5e614954ce9e99c8a481ea2ee73d72d61 * Enable full-duplex for HTTP/1.1 * Appease linter * Add warning for builds with Go 1.20, so it's less surprising to users * Improved godoc for EnableFullDuplex, copied text from stdlib * Only wrap in encode if not already wrapped
Diffstat (limited to 'modules/caddyhttp/reverseproxy')
-rw-r--r--modules/caddyhttp/reverseproxy/reverseproxy.go5
-rw-r--r--modules/caddyhttp/reverseproxy/streaming.go60
-rw-r--r--modules/caddyhttp/reverseproxy/streaming_test.go6
3 files changed, 38 insertions, 33 deletions
diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go
index b331c6b..842d75d 100644
--- a/modules/caddyhttp/reverseproxy/reverseproxy.go
+++ b/modules/caddyhttp/reverseproxy/reverseproxy.go
@@ -962,9 +962,8 @@ func (h *Handler) finalizeResponse(
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
- if fl, ok := rw.(http.Flusher); ok {
- fl.Flush()
- }
+ //nolint:bodyclose
+ http.NewResponseController(rw).Flush()
}
// total duration spent proxying, including writing response body
diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go
index 6c1e44c..3f2489d 100644
--- a/modules/caddyhttp/reverseproxy/streaming.go
+++ b/modules/caddyhttp/reverseproxy/streaming.go
@@ -20,6 +20,7 @@ package reverseproxy
import (
"context"
+ "errors"
"fmt"
"io"
weakrand "math/rand"
@@ -51,17 +52,19 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
return
}
- hj, ok := rw.(http.Hijacker)
- if !ok {
- logger.Error("can't switch protocols using non-Hijacker ResponseWriter", zap.String("type", fmt.Sprintf("%T", rw)))
- return
- }
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
logger.Error("internal error: 101 switching protocols response with non-writable body")
return
}
+ //nolint:bodyclose
+ conn, brw, hijackErr := http.NewResponseController(rw).Hijack()
+ if errors.Is(hijackErr, http.ErrNotSupported) {
+ h.logger.Sugar().Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)
+ return
+ }
+
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
backConnCloseCh := make(chan struct{})
go func() {
@@ -81,9 +84,8 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
rw.WriteHeader(res.StatusCode)
logger.Debug("upgrading connection")
- conn, brw, err := hj.Hijack()
- if err != nil {
- logger.Error("hijack failed on protocol switch", zap.Error(err))
+ if hijackErr != nil {
+ h.logger.Error("hijack failed on protocol switch", zap.Error(hijackErr))
return
}
@@ -181,26 +183,28 @@ func (h Handler) isBidirectionalStream(req *http.Request, res *http.Response) bo
(ae == "identity" || ae == "")
}
-func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
+func (h Handler) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
+ var w io.Writer = dst
+
if flushInterval != 0 {
- if wf, ok := dst.(writeFlusher); ok {
- mlw := &maxLatencyWriter{
- dst: wf,
- latency: flushInterval,
- }
- defer mlw.stop()
+ mlw := &maxLatencyWriter{
+ dst: dst,
+ //nolint:bodyclose
+ flush: http.NewResponseController(dst).Flush,
+ latency: flushInterval,
+ }
+ defer mlw.stop()
- // set up initial timer so headers get flushed even if body writes are delayed
- mlw.flushPending = true
- mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
+ // set up initial timer so headers get flushed even if body writes are delayed
+ mlw.flushPending = true
+ mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
- dst = mlw
- }
+ w = mlw
}
buf := streamingBufPool.Get().(*[]byte)
defer streamingBufPool.Put(buf)
- _, err := h.copyBuffer(dst, src, *buf)
+ _, err := h.copyBuffer(w, src, *buf)
return err
}
@@ -439,13 +443,9 @@ type openConnection struct {
gracefulClose func() error
}
-type writeFlusher interface {
- io.Writer
- http.Flusher
-}
-
type maxLatencyWriter struct {
- dst writeFlusher
+ dst io.Writer
+ flush func() error
latency time.Duration // non-zero; negative means to flush immediately
mu sync.Mutex // protects t, flushPending, and dst.Flush
@@ -458,7 +458,8 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
defer m.mu.Unlock()
n, err = m.dst.Write(p)
if m.latency < 0 {
- m.dst.Flush()
+ //nolint:errcheck
+ m.flush()
return
}
if m.flushPending {
@@ -479,7 +480,8 @@ func (m *maxLatencyWriter) delayedFlush() {
if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
return
}
- m.dst.Flush()
+ //nolint:errcheck
+ m.flush()
m.flushPending = false
}
diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go
index 4ed1f1e..919538f 100644
--- a/modules/caddyhttp/reverseproxy/streaming_test.go
+++ b/modules/caddyhttp/reverseproxy/streaming_test.go
@@ -2,6 +2,7 @@ package reverseproxy
import (
"bytes"
+ "net/http/httptest"
"strings"
"testing"
)
@@ -13,12 +14,15 @@ func TestHandlerCopyResponse(t *testing.T) {
strings.Repeat("a", defaultBufferSize),
strings.Repeat("123456789 123456789 123456789 12", 3000),
}
+
dst := bytes.NewBuffer(nil)
+ recorder := httptest.NewRecorder()
+ recorder.Body = dst
for _, d := range testdata {
src := bytes.NewBuffer([]byte(d))
dst.Reset()
- err := h.copyResponse(dst, src, 0)
+ err := h.copyResponse(recorder, src, 0)
if err != nil {
t.Errorf("failed with error: %v", err)
}