summaryrefslogtreecommitdiff
path: root/modules/caddyhttp/responsewriter.go
blob: 37c264634c88a3fed2cbbb4bdaaf67de5e73ff41 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package caddyhttp

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"net"
	"net/http"
)

// ResponseWriterWrapper wraps an underlying ResponseWriter and
// promotes its Pusher method as well. To use this type, embed
// a pointer to it within your own struct type that implements
// the http.ResponseWriter interface, then call methods on the
// embedded value.
type ResponseWriterWrapper struct {
	http.ResponseWriter
}

// Push implements http.Pusher. It simply calls the underlying
// ResponseWriter's Push method if there is one, or returns
// ErrNotImplemented otherwise.
func (rww *ResponseWriterWrapper) Push(target string, opts *http.PushOptions) error {
	if pusher, ok := rww.ResponseWriter.(http.Pusher); ok {
		return pusher.Push(target, opts)
	}
	return ErrNotImplemented
}

// ReadFrom implements io.ReaderFrom. It simply calls io.Copy,
// which uses io.ReaderFrom if available.
func (rww *ResponseWriterWrapper) ReadFrom(r io.Reader) (n int64, err error) {
	return io.Copy(rww.ResponseWriter, r)
}

// Unwrap returns the underlying ResponseWriter, necessary for
// http.ResponseController to work correctly.
func (rww *ResponseWriterWrapper) Unwrap() http.ResponseWriter {
	return rww.ResponseWriter
}

// ErrNotImplemented is returned when an underlying
// ResponseWriter does not implement the required method.
var ErrNotImplemented = fmt.Errorf("method not implemented")

type responseRecorder struct {
	*ResponseWriterWrapper
	statusCode   int
	buf          *bytes.Buffer
	shouldBuffer ShouldBufferFunc
	size         int
	wroteHeader  bool
	stream       bool
}

// NewResponseRecorder returns a new ResponseRecorder that can be
// used instead of a standard http.ResponseWriter. The recorder is
// useful for middlewares which need to buffer a response and
// potentially process its entire body before actually writing the
// response to the underlying writer. Of course, buffering the entire
// body has a memory overhead, but sometimes there is no way to avoid
// buffering the whole response, hence the existence of this type.
// Still, if at all practical, handlers should strive to stream
// responses by wrapping Write and WriteHeader methods instead of
// buffering whole response bodies.
//
// Buffering is actually optional. The shouldBuffer function will
// be called just before the headers are written. If it returns
// true, the headers and body will be buffered by this recorder
// and not written to the underlying writer; if false, the headers
// will be written immediately and the body will be streamed out
// directly to the underlying writer. If shouldBuffer is nil,
// the response will never be buffered and will always be streamed
// directly to the writer.
//
// You can know if shouldBuffer returned true by calling Buffered().
//
// The provided buffer buf should be obtained from a pool for best
// performance (see the sync.Pool type).
//
// Proper usage of a recorder looks like this:
//
//	rec := caddyhttp.NewResponseRecorder(w, buf, shouldBuffer)
//	err := next.ServeHTTP(rec, req)
//	if err != nil {
//	    return err
//	}
//	if !rec.Buffered() {
//	    return nil
//	}
//	// process the buffered response here
//
// The header map is not buffered; i.e. the ResponseRecorder's Header()
// method returns the same header map of the underlying ResponseWriter.
// This is a crucial design decision to allow HTTP trailers to be
// flushed properly (https://github.com/caddyserver/caddy/issues/3236).
//
// Once you are ready to write the response, there are two ways you can
// do it. The easier way is to have the recorder do it:
//
//	rec.WriteResponse()
//
// This writes the recorded response headers as well as the buffered body.
// Or, you may wish to do it yourself, especially if you manipulated the
// buffered body. First you will need to write the headers with the
// recorded status code, then write the body (this example writes the
// recorder's body buffer, but you might have your own body to write
// instead):
//
//	w.WriteHeader(rec.Status())
//	io.Copy(w, rec.Buffer())
//
// As a special case, 1xx responses are not buffered nor recorded
// because they are not the final response; they are passed through
// directly to the underlying ResponseWriter.
func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer ShouldBufferFunc) ResponseRecorder {
	return &responseRecorder{
		ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: w},
		buf:                   buf,
		shouldBuffer:          shouldBuffer,
	}
}

// WriteHeader writes the headers with statusCode to the wrapped
// ResponseWriter unless the response is to be buffered instead.
// 1xx responses are never buffered.
func (rr *responseRecorder) WriteHeader(statusCode int) {
	if rr.wroteHeader {
		return
	}

	// save statusCode always, in case HTTP middleware upgrades websocket
	// connections by manually setting headers and writing status 101
	rr.statusCode = statusCode

	// 1xx responses aren't final; just informational
	if statusCode < 100 || statusCode > 199 {
		rr.wroteHeader = true

		// decide whether we should buffer the response
		if rr.shouldBuffer == nil {
			rr.stream = true
		} else {
			rr.stream = !rr.shouldBuffer(rr.statusCode, rr.ResponseWriterWrapper.Header())
		}
	}

	// if informational or not buffered, immediately write header
	if rr.stream || (100 <= statusCode && statusCode <= 199) {
		rr.ResponseWriterWrapper.WriteHeader(statusCode)
	}
}

func (rr *responseRecorder) Write(data []byte) (int, error) {
	rr.WriteHeader(http.StatusOK)
	var n int
	var err error
	if rr.stream {
		n, err = rr.ResponseWriterWrapper.Write(data)
	} else {
		n, err = rr.buf.Write(data)
	}

	rr.size += n
	return n, err
}

func (rr *responseRecorder) ReadFrom(r io.Reader) (int64, error) {
	rr.WriteHeader(http.StatusOK)
	var n int64
	var err error
	if rr.stream {
		n, err = rr.ResponseWriterWrapper.ReadFrom(r)
	} else {
		n, err = rr.buf.ReadFrom(r)
	}

	rr.size += int(n)
	return n, err
}

// Status returns the status code that was written, if any.
func (rr *responseRecorder) Status() int {
	return rr.statusCode
}

// Size returns the number of bytes written,
// not including the response headers.
func (rr *responseRecorder) Size() int {
	return rr.size
}

// Buffer returns the body buffer that rr was created with.
// You should still have your original pointer, though.
func (rr *responseRecorder) Buffer() *bytes.Buffer {
	return rr.buf
}

// Buffered returns whether rr has decided to buffer the response.
func (rr *responseRecorder) Buffered() bool {
	return !rr.stream
}

func (rr *responseRecorder) WriteResponse() error {
	if rr.stream {
		return nil
	}
	if rr.statusCode == 0 {
		// could happen if no handlers actually wrote anything,
		// and this prevents a panic; status must be > 0
		rr.statusCode = http.StatusOK
	}
	rr.ResponseWriterWrapper.WriteHeader(rr.statusCode)
	_, err := io.Copy(rr.ResponseWriterWrapper, rr.buf)
	return err
}

func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	//nolint:bodyclose
	conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
	if err != nil {
		return nil, nil, err
	}
	// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
	conn = &hijackedConn{conn, rr}
	brw.Writer.Reset(conn)
	return conn, brw, nil
}

// used to track the size of hijacked response writers
type hijackedConn struct {
	net.Conn
	rr *responseRecorder
}

func (hc *hijackedConn) Write(p []byte) (int, error) {
	n, err := hc.Conn.Write(p)
	hc.rr.size += n
	return n, err
}

func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) {
	n, err := io.Copy(hc.Conn, r)
	hc.rr.size += int(n)
	return n, err
}

// ResponseRecorder is a http.ResponseWriter that records
// responses instead of writing them to the client. See
// docs for NewResponseRecorder for proper usage.
type ResponseRecorder interface {
	http.ResponseWriter
	Status() int
	Buffer() *bytes.Buffer
	Buffered() bool
	Size() int
	WriteResponse() error
}

// ShouldBufferFunc is a function that returns true if the
// response should be buffered, given the pending HTTP status
// code and response headers.
type ShouldBufferFunc func(status int, header http.Header) bool

// Interface guards
var (
	_ http.ResponseWriter = (*ResponseWriterWrapper)(nil)
	_ ResponseRecorder    = (*responseRecorder)(nil)

	// Implementing ReaderFrom can be such a significant
	// optimization that it should probably be required!
	// see PR #5022 (25%-50% speedup)
	_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
	_ io.ReaderFrom = (*responseRecorder)(nil)
	_ io.ReaderFrom = (*hijackedConn)(nil)
)