diff options
Diffstat (limited to 'modules/caddyhttp')
-rw-r--r-- | modules/caddyhttp/headers/headers.go | 11 | ||||
-rw-r--r-- | modules/caddyhttp/matchers.go | 55 | ||||
-rw-r--r-- | modules/caddyhttp/matchers_test.go | 131 |
3 files changed, 194 insertions, 3 deletions
diff --git a/modules/caddyhttp/headers/headers.go b/modules/caddyhttp/headers/headers.go index 4cab5b5..b07a588 100644 --- a/modules/caddyhttp/headers/headers.go +++ b/modules/caddyhttp/headers/headers.go @@ -33,16 +33,18 @@ type HeaderOps struct { // optionally deferred until response time. type RespHeaderOps struct { *HeaderOps - Deferred bool `json:"deferred"` + Require *caddyhttp.ResponseMatcher `json:"require,omitempty"` + Deferred bool `json:"deferred,omitempty"` } func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { repl := r.Context().Value(caddy2.ReplacerCtxKey).(caddy2.Replacer) apply(h.Request, r.Header, repl) - if h.Response.Deferred { + if h.Response.Deferred || h.Response.Require != nil { w = &responseWriterWrapper{ ResponseWriterWrapper: &caddyhttp.ResponseWriterWrapper{ResponseWriter: w}, replacer: repl, + require: h.Response.Require, headerOps: h.Response.HeaderOps, } } else { @@ -75,6 +77,7 @@ func apply(ops *HeaderOps, hdr http.Header, repl caddy2.Replacer) { type responseWriterWrapper struct { *caddyhttp.ResponseWriterWrapper replacer caddy2.Replacer + require *caddyhttp.ResponseMatcher headerOps *HeaderOps wroteHeader bool } @@ -91,7 +94,9 @@ func (rww *responseWriterWrapper) WriteHeader(status int) { return } rww.wroteHeader = true - apply(rww.headerOps, rww.ResponseWriterWrapper.Header(), rww.replacer) + if rww.require == nil || rww.require.Match(status, rww.ResponseWriterWrapper.Header()) { + apply(rww.headerOps, rww.ResponseWriterWrapper.Header(), rww.replacer) + } rww.ResponseWriterWrapper.WriteHeader(status) } diff --git a/modules/caddyhttp/matchers.go b/modules/caddyhttp/matchers.go index 33300da..0dda205 100644 --- a/modules/caddyhttp/matchers.go +++ b/modules/caddyhttp/matchers.go @@ -298,6 +298,61 @@ func (mre *MatchRegexp) Match(input string, repl caddy2.Replacer, scope string) return true } +// ResponseMatcher is a type which can determine if a given response +// status code and its headers match some criteria. +type ResponseMatcher struct { + // If set, one of these status codes would be required. + // A one-digit status can be used to represent all codes + // in that class (e.g. 3 for all 3xx codes). + StatusCode []int `json:"status_code,omitempty"` + + // If set, each header specified must be one of the specified values. + Headers http.Header `json:"headers,omitempty"` +} + +// Match returns true if the given statusCode and hdr match rm. +func (rm ResponseMatcher) Match(statusCode int, hdr http.Header) bool { + if !rm.matchStatusCode(statusCode) { + return false + } + return rm.matchHeaders(hdr) +} + +func (rm ResponseMatcher) matchStatusCode(statusCode int) bool { + if rm.StatusCode == nil { + return true + } + for _, code := range rm.StatusCode { + if statusCode == code { + return true + } + if code < 100 && statusCode >= code*100 && statusCode < (code+1)*100 { + return true + } + } + return false +} + +func (rm ResponseMatcher) matchHeaders(hdr http.Header) bool { + for field, allowedFieldVals := range rm.Headers { + var match bool + actualFieldVals := hdr[textproto.CanonicalMIMEHeaderKey(field)] + fieldVals: + for _, actualFieldVal := range actualFieldVals { + for _, allowedFieldVal := range allowedFieldVals { + if actualFieldVal == allowedFieldVal { + match = true + break fieldVals + } + } + } + if !match { + return false + } + } + return true +} + var wordRE = regexp.MustCompile(`\w+`) // Interface guards diff --git a/modules/caddyhttp/matchers_test.go b/modules/caddyhttp/matchers_test.go index 5e62a90..1297bc8 100644 --- a/modules/caddyhttp/matchers_test.go +++ b/modules/caddyhttp/matchers_test.go @@ -368,3 +368,134 @@ func TestHeaderREMatcher(t *testing.T) { } } } + +func TestResponseMatcher(t *testing.T) { + for i, tc := range []struct { + require ResponseMatcher + status int + hdr http.Header // make sure these are canonical cased (std lib will do that in a real request) + expect bool + }{ + { + require: ResponseMatcher{}, + status: 200, + expect: true, + }, + { + require: ResponseMatcher{ + StatusCode: []int{200}, + }, + status: 200, + expect: true, + }, + { + require: ResponseMatcher{ + StatusCode: []int{2}, + }, + status: 200, + expect: true, + }, + { + require: ResponseMatcher{ + StatusCode: []int{201}, + }, + status: 200, + expect: false, + }, + { + require: ResponseMatcher{ + StatusCode: []int{2}, + }, + status: 301, + expect: false, + }, + { + require: ResponseMatcher{ + StatusCode: []int{3}, + }, + status: 301, + expect: true, + }, + { + require: ResponseMatcher{ + StatusCode: []int{3}, + }, + status: 399, + expect: true, + }, + { + require: ResponseMatcher{ + StatusCode: []int{3}, + }, + status: 400, + expect: false, + }, + { + require: ResponseMatcher{ + StatusCode: []int{3, 4}, + }, + status: 400, + expect: true, + }, + { + require: ResponseMatcher{ + StatusCode: []int{3, 401}, + }, + status: 401, + expect: true, + }, + { + require: ResponseMatcher{ + Headers: http.Header{ + "Foo": []string{"bar"}, + }, + }, + hdr: http.Header{"Foo": []string{"bar"}}, + expect: true, + }, + { + require: ResponseMatcher{ + Headers: http.Header{ + "Foo2": []string{"bar"}, + }, + }, + hdr: http.Header{"Foo": []string{"bar"}}, + expect: false, + }, + { + require: ResponseMatcher{ + Headers: http.Header{ + "Foo": []string{"bar", "baz"}, + }, + }, + hdr: http.Header{"Foo": []string{"baz"}}, + expect: true, + }, + { + require: ResponseMatcher{ + Headers: http.Header{ + "Foo": []string{"bar"}, + "Foo2": []string{"baz"}, + }, + }, + hdr: http.Header{"Foo": []string{"baz"}}, + expect: false, + }, + { + require: ResponseMatcher{ + Headers: http.Header{ + "Foo": []string{"bar"}, + "Foo2": []string{"baz"}, + }, + }, + hdr: http.Header{"Foo": []string{"bar"}, "Foo2": []string{"baz"}}, + expect: true, + }, + } { + actual := tc.require.Match(tc.status, tc.hdr) + if actual != tc.expect { + t.Errorf("Test %d %v: Expected %t, got %t for HTTP %d %v", i, tc.require, tc.expect, actual, tc.status, tc.hdr) + continue + } + } +} |