From dd9813c65be3af8e222bda6e63f5eea9232c6eee Mon Sep 17 00:00:00 2001 From: fleandro <3987005+flga@users.noreply.github.com> Date: Wed, 7 Sep 2022 21:13:35 +0100 Subject: caddyhttp: ensure ResponseWriterWrapper and ResponseRecorder use ReadFrom if the underlying response writer implements it. (#5022) Doing so allows for splice/sendfile optimizations when available. Fixes #4731 Co-authored-by: flga Co-authored-by: Matthew Holt --- modules/caddyhttp/responsewriter_test.go | 165 +++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 modules/caddyhttp/responsewriter_test.go (limited to 'modules/caddyhttp/responsewriter_test.go') diff --git a/modules/caddyhttp/responsewriter_test.go b/modules/caddyhttp/responsewriter_test.go new file mode 100644 index 0000000..1913932 --- /dev/null +++ b/modules/caddyhttp/responsewriter_test.go @@ -0,0 +1,165 @@ +package caddyhttp + +import ( + "bytes" + "fmt" + "io" + "net/http" + "strings" + "testing" +) + +type responseWriterSpy interface { + http.ResponseWriter + Written() string + CalledReadFrom() bool +} + +var ( + _ responseWriterSpy = (*baseRespWriter)(nil) + _ responseWriterSpy = (*readFromRespWriter)(nil) +) + +// a barebones http.ResponseWriter mock +type baseRespWriter []byte + +func (brw *baseRespWriter) Write(d []byte) (int, error) { + *brw = append(*brw, d...) + return len(d), nil +} +func (brw *baseRespWriter) Header() http.Header { return nil } +func (brw *baseRespWriter) WriteHeader(statusCode int) {} +func (brw *baseRespWriter) Written() string { return string(*brw) } +func (brw *baseRespWriter) CalledReadFrom() bool { return false } + +// an http.ResponseWriter mock that supports ReadFrom +type readFromRespWriter struct { + baseRespWriter + called bool +} + +func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) { + rf.called = true + return io.Copy(&rf.baseRespWriter, r) +} + +func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called } + +func TestResponseWriterWrapperReadFrom(t *testing.T) { + tests := map[string]struct { + responseWriter responseWriterSpy + wantReadFrom bool + }{ + "no ReadFrom": { + responseWriter: &baseRespWriter{}, + wantReadFrom: false, + }, + "has ReadFrom": { + responseWriter: &readFromRespWriter{}, + wantReadFrom: true, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + // what we expect middlewares to do: + type myWrapper struct { + *ResponseWriterWrapper + } + + wrapped := myWrapper{ + ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: tt.responseWriter}, + } + + const srcData = "boo!" + // hides everything but Read, since strings.Reader implements WriteTo it would + // take precedence over our ReadFrom. + src := struct{ io.Reader }{strings.NewReader(srcData)} + + fmt.Println(name) + if _, err := io.Copy(wrapped, src); err != nil { + t.Errorf("Copy() err = %v", err) + } + + if got := tt.responseWriter.Written(); got != srcData { + t.Errorf("data = %q, want %q", got, srcData) + } + + if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom { + if tt.wantReadFrom { + t.Errorf("ReadFrom() should have been called") + } else { + t.Errorf("ReadFrom() should not have been called") + } + } + }) + } +} + +func TestResponseRecorderReadFrom(t *testing.T) { + tests := map[string]struct { + responseWriter responseWriterSpy + shouldBuffer bool + wantReadFrom bool + }{ + "buffered plain": { + responseWriter: &baseRespWriter{}, + shouldBuffer: true, + wantReadFrom: false, + }, + "streamed plain": { + responseWriter: &baseRespWriter{}, + shouldBuffer: false, + wantReadFrom: false, + }, + "buffered ReadFrom": { + responseWriter: &readFromRespWriter{}, + shouldBuffer: true, + wantReadFrom: false, + }, + "streamed ReadFrom": { + responseWriter: &readFromRespWriter{}, + shouldBuffer: false, + wantReadFrom: true, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + var buf bytes.Buffer + + rr := NewResponseRecorder(tt.responseWriter, &buf, func(status int, header http.Header) bool { + return tt.shouldBuffer + }) + + const srcData = "boo!" + // hides everything but Read, since strings.Reader implements WriteTo it would + // take precedence over our ReadFrom. + src := struct{ io.Reader }{strings.NewReader(srcData)} + + if _, err := io.Copy(rr, src); err != nil { + t.Errorf("Copy() err = %v", err) + } + + wantStreamed := srcData + wantBuffered := "" + if tt.shouldBuffer { + wantStreamed = "" + wantBuffered = srcData + } + + if got := tt.responseWriter.Written(); got != wantStreamed { + t.Errorf("streamed data = %q, want %q", got, wantStreamed) + } + if got := buf.String(); got != wantBuffered { + t.Errorf("buffered data = %q, want %q", got, wantBuffered) + } + + if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom { + if tt.wantReadFrom { + t.Errorf("ReadFrom() should have been called") + } else { + t.Errorf("ReadFrom() should not have been called") + } + } + }) + } +} -- cgit v1.2.3