From 65195a726d9ceff4bbf870b7baa7eff20cf35381 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Mon, 20 May 2019 23:48:43 -0600 Subject: Implement rewrite middleware; fix middleware stack bugs --- modules/caddyhttp/routes.go | 49 ++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 16 deletions(-) (limited to 'modules/caddyhttp/routes.go') diff --git a/modules/caddyhttp/routes.go b/modules/caddyhttp/routes.go index daae080..92aa3e8 100644 --- a/modules/caddyhttp/routes.go +++ b/modules/caddyhttp/routes.go @@ -65,23 +65,24 @@ func (routes RouteList) Provision(ctx caddy2.Context) error { return nil } -// BuildCompositeRoute creates a chain of handlers by -// applying all the matching routes. -func (routes RouteList) BuildCompositeRoute(w http.ResponseWriter, r *http.Request) Handler { +// BuildCompositeRoute creates a chain of handlers by applying all the matching +// routes. The returned ResponseWriter should be used instead of rw. +func (routes RouteList) BuildCompositeRoute(rw http.ResponseWriter, req *http.Request) (Handler, http.ResponseWriter) { + mrw := &middlewareResponseWriter{ResponseWriterWrapper: &ResponseWriterWrapper{rw}} + if len(routes) == 0 { - return emptyHandler + return emptyHandler, mrw } var mid []Middleware var responder Handler - mrw := &middlewareResponseWriter{ResponseWriterWrapper: &ResponseWriterWrapper{w}} groups := make(map[string]struct{}) routeLoop: for _, route := range routes { // see if route matches for _, m := range route.matchers { - if !m.Match(r) { + if !m.Match(req) { continue routeLoop } } @@ -102,15 +103,13 @@ routeLoop: // apply the rest of the route for _, m := range route.middleware { - mid = append(mid, func(next HandlerFunc) HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) error { - // TODO: This is where request tracing could be implemented; also - // see below to trace the responder as well - // TODO: Trace a diff of the request, would be cool too! see what changed since the last middleware (host, headers, URI...) - // TODO: see what the std lib gives us in terms of stack trracing too - return m.ServeHTTP(mrw, r, next) - } - }) + // we have to be sure to wrap m outside + // of our current scope so that the + // reference to this m isn't overwritten + // on the next iteration, leaving only + // the last middleware in the chain as + // the ONLY middleware in the chain! + mid = append(mid, wrapMiddleware(m)) } if responder == nil { responder = route.responder @@ -132,7 +131,25 @@ routeLoop: stack = mid[i](stack) } - return stack + return stack, mrw +} + +// wrapMiddleware wraps m such that it can be correctly +// appended to a list of middleware. This is necessary +// so that only the last middleware in a loop does not +// become the only middleware of the stack, repeatedly +// executed (i.e. it is necessary to keep a reference +// to this m outside of the scope of a loop)! +func wrapMiddleware(m MiddlewareHandler) Middleware { + return func(next HandlerFunc) HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { + // TODO: This is where request tracing could be implemented; also + // see below to trace the responder as well + // TODO: Trace a diff of the request, would be cool too! see what changed since the last middleware (host, headers, URI...) + // TODO: see what the std lib gives us in terms of stack tracing too + return m.ServeHTTP(w, r, next) + } + } } type middlewareResponseWriter struct { -- cgit v1.2.3