diff options
Diffstat (limited to 'modules/caddyhttp')
-rw-r--r-- | modules/caddyhttp/responsewriter.go | 37 | ||||
-rw-r--r-- | modules/caddyhttp/responsewriter_test.go | 165 |
2 files changed, 200 insertions, 2 deletions
diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index 374bbfb..9820b41 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -62,6 +62,16 @@ func (rww *ResponseWriterWrapper) Push(target string, opts *http.PushOptions) er return ErrNotImplemented } +// ReadFrom implements io.ReaderFrom. It simply calls the underlying +// ResponseWriter's ReadFrom method if there is one, otherwise it defaults +// to io.Copy. +func (rww *ResponseWriterWrapper) ReadFrom(r io.Reader) (n int64, err error) { + if rf, ok := rww.ResponseWriter.(io.ReaderFrom); ok { + return rf.ReadFrom(r) + } + return io.Copy(rww.ResponseWriter, r) +} + // HTTPInterfaces mix all the interfaces that middleware ResponseWriters need to support. type HTTPInterfaces interface { http.ResponseWriter @@ -188,9 +198,26 @@ func (rr *responseRecorder) Write(data []byte) (int, error) { } else { n, err = rr.buf.Write(data) } - if err == nil { - rr.size += n + + rr.size += n + return n, err +} + +func (rr *responseRecorder) ReadFrom(r io.Reader) (int64, error) { + rr.WriteHeader(http.StatusOK) + var n int64 + var err error + if rr.stream { + if rf, ok := rr.ResponseWriter.(io.ReaderFrom); ok { + n, err = rf.ReadFrom(r) + } else { + n, err = io.Copy(rr.ResponseWriter, r) + } + } else { + n, err = rr.buf.ReadFrom(r) } + + rr.size += int(n) return n, err } @@ -251,4 +278,10 @@ type ShouldBufferFunc func(status int, header http.Header) bool var ( _ HTTPInterfaces = (*ResponseWriterWrapper)(nil) _ ResponseRecorder = (*responseRecorder)(nil) + + // Implementing ReaderFrom can be such a significant + // optimization that it should probably be required! + // see PR #5022 (25%-50% speedup) + _ io.ReaderFrom = (*ResponseWriterWrapper)(nil) + _ io.ReaderFrom = (*responseRecorder)(nil) ) diff --git a/modules/caddyhttp/responsewriter_test.go b/modules/caddyhttp/responsewriter_test.go new file mode 100644 index 0000000..1913932 --- /dev/null +++ b/modules/caddyhttp/responsewriter_test.go @@ -0,0 +1,165 @@ +package caddyhttp + +import ( + "bytes" + "fmt" + "io" + "net/http" + "strings" + "testing" +) + +type responseWriterSpy interface { + http.ResponseWriter + Written() string + CalledReadFrom() bool +} + +var ( + _ responseWriterSpy = (*baseRespWriter)(nil) + _ responseWriterSpy = (*readFromRespWriter)(nil) +) + +// a barebones http.ResponseWriter mock +type baseRespWriter []byte + +func (brw *baseRespWriter) Write(d []byte) (int, error) { + *brw = append(*brw, d...) + return len(d), nil +} +func (brw *baseRespWriter) Header() http.Header { return nil } +func (brw *baseRespWriter) WriteHeader(statusCode int) {} +func (brw *baseRespWriter) Written() string { return string(*brw) } +func (brw *baseRespWriter) CalledReadFrom() bool { return false } + +// an http.ResponseWriter mock that supports ReadFrom +type readFromRespWriter struct { + baseRespWriter + called bool +} + +func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) { + rf.called = true + return io.Copy(&rf.baseRespWriter, r) +} + +func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called } + +func TestResponseWriterWrapperReadFrom(t *testing.T) { + tests := map[string]struct { + responseWriter responseWriterSpy + wantReadFrom bool + }{ + "no ReadFrom": { + responseWriter: &baseRespWriter{}, + wantReadFrom: false, + }, + "has ReadFrom": { + responseWriter: &readFromRespWriter{}, + wantReadFrom: true, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + // what we expect middlewares to do: + type myWrapper struct { + *ResponseWriterWrapper + } + + wrapped := myWrapper{ + ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: tt.responseWriter}, + } + + const srcData = "boo!" + // hides everything but Read, since strings.Reader implements WriteTo it would + // take precedence over our ReadFrom. + src := struct{ io.Reader }{strings.NewReader(srcData)} + + fmt.Println(name) + if _, err := io.Copy(wrapped, src); err != nil { + t.Errorf("Copy() err = %v", err) + } + + if got := tt.responseWriter.Written(); got != srcData { + t.Errorf("data = %q, want %q", got, srcData) + } + + if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom { + if tt.wantReadFrom { + t.Errorf("ReadFrom() should have been called") + } else { + t.Errorf("ReadFrom() should not have been called") + } + } + }) + } +} + +func TestResponseRecorderReadFrom(t *testing.T) { + tests := map[string]struct { + responseWriter responseWriterSpy + shouldBuffer bool + wantReadFrom bool + }{ + "buffered plain": { + responseWriter: &baseRespWriter{}, + shouldBuffer: true, + wantReadFrom: false, + }, + "streamed plain": { + responseWriter: &baseRespWriter{}, + shouldBuffer: false, + wantReadFrom: false, + }, + "buffered ReadFrom": { + responseWriter: &readFromRespWriter{}, + shouldBuffer: true, + wantReadFrom: false, + }, + "streamed ReadFrom": { + responseWriter: &readFromRespWriter{}, + shouldBuffer: false, + wantReadFrom: true, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + var buf bytes.Buffer + + rr := NewResponseRecorder(tt.responseWriter, &buf, func(status int, header http.Header) bool { + return tt.shouldBuffer + }) + + const srcData = "boo!" + // hides everything but Read, since strings.Reader implements WriteTo it would + // take precedence over our ReadFrom. + src := struct{ io.Reader }{strings.NewReader(srcData)} + + if _, err := io.Copy(rr, src); err != nil { + t.Errorf("Copy() err = %v", err) + } + + wantStreamed := srcData + wantBuffered := "" + if tt.shouldBuffer { + wantStreamed = "" + wantBuffered = srcData + } + + if got := tt.responseWriter.Written(); got != wantStreamed { + t.Errorf("streamed data = %q, want %q", got, wantStreamed) + } + if got := buf.String(); got != wantBuffered { + t.Errorf("buffered data = %q, want %q", got, wantBuffered) + } + + if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom { + if tt.wantReadFrom { + t.Errorf("ReadFrom() should have been called") + } else { + t.Errorf("ReadFrom() should not have been called") + } + } + }) + } +} |