summaryrefslogtreecommitdiff
path: root/modules/caddyhttp
diff options
context:
space:
mode:
Diffstat (limited to 'modules/caddyhttp')
-rw-r--r--modules/caddyhttp/headers/headers.go11
-rw-r--r--modules/caddyhttp/matchers.go55
-rw-r--r--modules/caddyhttp/matchers_test.go131
3 files changed, 194 insertions, 3 deletions
diff --git a/modules/caddyhttp/headers/headers.go b/modules/caddyhttp/headers/headers.go
index 4cab5b5..b07a588 100644
--- a/modules/caddyhttp/headers/headers.go
+++ b/modules/caddyhttp/headers/headers.go
@@ -33,16 +33,18 @@ type HeaderOps struct {
// optionally deferred until response time.
type RespHeaderOps struct {
*HeaderOps
- Deferred bool `json:"deferred"`
+ Require *caddyhttp.ResponseMatcher `json:"require,omitempty"`
+ Deferred bool `json:"deferred,omitempty"`
}
func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
repl := r.Context().Value(caddy2.ReplacerCtxKey).(caddy2.Replacer)
apply(h.Request, r.Header, repl)
- if h.Response.Deferred {
+ if h.Response.Deferred || h.Response.Require != nil {
w = &responseWriterWrapper{
ResponseWriterWrapper: &caddyhttp.ResponseWriterWrapper{ResponseWriter: w},
replacer: repl,
+ require: h.Response.Require,
headerOps: h.Response.HeaderOps,
}
} else {
@@ -75,6 +77,7 @@ func apply(ops *HeaderOps, hdr http.Header, repl caddy2.Replacer) {
type responseWriterWrapper struct {
*caddyhttp.ResponseWriterWrapper
replacer caddy2.Replacer
+ require *caddyhttp.ResponseMatcher
headerOps *HeaderOps
wroteHeader bool
}
@@ -91,7 +94,9 @@ func (rww *responseWriterWrapper) WriteHeader(status int) {
return
}
rww.wroteHeader = true
- apply(rww.headerOps, rww.ResponseWriterWrapper.Header(), rww.replacer)
+ if rww.require == nil || rww.require.Match(status, rww.ResponseWriterWrapper.Header()) {
+ apply(rww.headerOps, rww.ResponseWriterWrapper.Header(), rww.replacer)
+ }
rww.ResponseWriterWrapper.WriteHeader(status)
}
diff --git a/modules/caddyhttp/matchers.go b/modules/caddyhttp/matchers.go
index 33300da..0dda205 100644
--- a/modules/caddyhttp/matchers.go
+++ b/modules/caddyhttp/matchers.go
@@ -298,6 +298,61 @@ func (mre *MatchRegexp) Match(input string, repl caddy2.Replacer, scope string)
return true
}
+// ResponseMatcher is a type which can determine if a given response
+// status code and its headers match some criteria.
+type ResponseMatcher struct {
+ // If set, one of these status codes would be required.
+ // A one-digit status can be used to represent all codes
+ // in that class (e.g. 3 for all 3xx codes).
+ StatusCode []int `json:"status_code,omitempty"`
+
+ // If set, each header specified must be one of the specified values.
+ Headers http.Header `json:"headers,omitempty"`
+}
+
+// Match returns true if the given statusCode and hdr match rm.
+func (rm ResponseMatcher) Match(statusCode int, hdr http.Header) bool {
+ if !rm.matchStatusCode(statusCode) {
+ return false
+ }
+ return rm.matchHeaders(hdr)
+}
+
+func (rm ResponseMatcher) matchStatusCode(statusCode int) bool {
+ if rm.StatusCode == nil {
+ return true
+ }
+ for _, code := range rm.StatusCode {
+ if statusCode == code {
+ return true
+ }
+ if code < 100 && statusCode >= code*100 && statusCode < (code+1)*100 {
+ return true
+ }
+ }
+ return false
+}
+
+func (rm ResponseMatcher) matchHeaders(hdr http.Header) bool {
+ for field, allowedFieldVals := range rm.Headers {
+ var match bool
+ actualFieldVals := hdr[textproto.CanonicalMIMEHeaderKey(field)]
+ fieldVals:
+ for _, actualFieldVal := range actualFieldVals {
+ for _, allowedFieldVal := range allowedFieldVals {
+ if actualFieldVal == allowedFieldVal {
+ match = true
+ break fieldVals
+ }
+ }
+ }
+ if !match {
+ return false
+ }
+ }
+ return true
+}
+
var wordRE = regexp.MustCompile(`\w+`)
// Interface guards
diff --git a/modules/caddyhttp/matchers_test.go b/modules/caddyhttp/matchers_test.go
index 5e62a90..1297bc8 100644
--- a/modules/caddyhttp/matchers_test.go
+++ b/modules/caddyhttp/matchers_test.go
@@ -368,3 +368,134 @@ func TestHeaderREMatcher(t *testing.T) {
}
}
}
+
+func TestResponseMatcher(t *testing.T) {
+ for i, tc := range []struct {
+ require ResponseMatcher
+ status int
+ hdr http.Header // make sure these are canonical cased (std lib will do that in a real request)
+ expect bool
+ }{
+ {
+ require: ResponseMatcher{},
+ status: 200,
+ expect: true,
+ },
+ {
+ require: ResponseMatcher{
+ StatusCode: []int{200},
+ },
+ status: 200,
+ expect: true,
+ },
+ {
+ require: ResponseMatcher{
+ StatusCode: []int{2},
+ },
+ status: 200,
+ expect: true,
+ },
+ {
+ require: ResponseMatcher{
+ StatusCode: []int{201},
+ },
+ status: 200,
+ expect: false,
+ },
+ {
+ require: ResponseMatcher{
+ StatusCode: []int{2},
+ },
+ status: 301,
+ expect: false,
+ },
+ {
+ require: ResponseMatcher{
+ StatusCode: []int{3},
+ },
+ status: 301,
+ expect: true,
+ },
+ {
+ require: ResponseMatcher{
+ StatusCode: []int{3},
+ },
+ status: 399,
+ expect: true,
+ },
+ {
+ require: ResponseMatcher{
+ StatusCode: []int{3},
+ },
+ status: 400,
+ expect: false,
+ },
+ {
+ require: ResponseMatcher{
+ StatusCode: []int{3, 4},
+ },
+ status: 400,
+ expect: true,
+ },
+ {
+ require: ResponseMatcher{
+ StatusCode: []int{3, 401},
+ },
+ status: 401,
+ expect: true,
+ },
+ {
+ require: ResponseMatcher{
+ Headers: http.Header{
+ "Foo": []string{"bar"},
+ },
+ },
+ hdr: http.Header{"Foo": []string{"bar"}},
+ expect: true,
+ },
+ {
+ require: ResponseMatcher{
+ Headers: http.Header{
+ "Foo2": []string{"bar"},
+ },
+ },
+ hdr: http.Header{"Foo": []string{"bar"}},
+ expect: false,
+ },
+ {
+ require: ResponseMatcher{
+ Headers: http.Header{
+ "Foo": []string{"bar", "baz"},
+ },
+ },
+ hdr: http.Header{"Foo": []string{"baz"}},
+ expect: true,
+ },
+ {
+ require: ResponseMatcher{
+ Headers: http.Header{
+ "Foo": []string{"bar"},
+ "Foo2": []string{"baz"},
+ },
+ },
+ hdr: http.Header{"Foo": []string{"baz"}},
+ expect: false,
+ },
+ {
+ require: ResponseMatcher{
+ Headers: http.Header{
+ "Foo": []string{"bar"},
+ "Foo2": []string{"baz"},
+ },
+ },
+ hdr: http.Header{"Foo": []string{"bar"}, "Foo2": []string{"baz"}},
+ expect: true,
+ },
+ } {
+ actual := tc.require.Match(tc.status, tc.hdr)
+ if actual != tc.expect {
+ t.Errorf("Test %d %v: Expected %t, got %t for HTTP %d %v", i, tc.require, tc.expect, actual, tc.status, tc.hdr)
+ continue
+ }
+ }
+}