diff options
author | Matthew Holt <mholt@users.noreply.github.com> | 2019-10-16 15:18:02 -0600 |
---|---|---|
committer | Matthew Holt <mholt@users.noreply.github.com> | 2019-10-16 15:18:02 -0600 |
commit | a458544d9f6e6aaf72aeab0454acfa482880d3d6 (patch) | |
tree | 7fa6b1806809ede3f81e9949ec819864f294526a /modules | |
parent | 2f91b44587fe487b6a05c1c7e56833247e4b8c79 (diff) |
Minor enhancements/fixes to rewrite directive and template virt req's
Diffstat (limited to 'modules')
-rw-r--r-- | modules/caddyhttp/rewrite/caddyfile.go | 3 | ||||
-rw-r--r-- | modules/caddyhttp/templates/tplcontext.go | 20 |
2 files changed, 20 insertions, 3 deletions
diff --git a/modules/caddyhttp/rewrite/caddyfile.go b/modules/caddyhttp/rewrite/caddyfile.go index cb65d43..a1fc874 100644 --- a/modules/caddyhttp/rewrite/caddyfile.go +++ b/modules/caddyhttp/rewrite/caddyfile.go @@ -31,6 +31,9 @@ func init() { func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { var rewr Rewrite for h.Next() { + if !h.NextArg() { + return nil, h.ArgErr() + } rewr.URI = h.Val() if h.NextArg() { return nil, h.ArgErr() diff --git a/modules/caddyhttp/templates/tplcontext.go b/modules/caddyhttp/templates/tplcontext.go index 40d1370..e3909b2 100644 --- a/modules/caddyhttp/templates/tplcontext.go +++ b/modules/caddyhttp/templates/tplcontext.go @@ -22,6 +22,7 @@ import ( "net" "net/http" "path" + "strconv" "strings" "sync" @@ -79,8 +80,18 @@ func (c templateContext) Include(filename string, args ...interface{}) (template // are NOT escaped, so you should only include trusted resources. // If it is not trusted, be sure to use escaping functions yourself. func (c templateContext) HTTPInclude(uri string) (template.HTML, error) { - if c.Req.Header.Get(recursionPreventionHeader) == "1" { - return "", fmt.Errorf("virtual request cycle") + // prevent virtual request loops by counting how many levels + // deep we are; and if we get too deep, return an error + recursionCount := 1 + if numStr := c.Req.Header.Get(recursionPreventionHeader); numStr != "" { + num, err := strconv.Atoi(numStr) + if err != nil { + return "", fmt.Errorf("parsing %s: %v", recursionPreventionHeader, err) + } + if num >= 3 { + return "", fmt.Errorf("virtual request cycle") + } + recursionCount = num + 1 } buf := bufPool.Get().(*bytes.Buffer) @@ -91,7 +102,10 @@ func (c templateContext) HTTPInclude(uri string) (template.HTML, error) { if err != nil { return "", err } - virtReq.Header.Set(recursionPreventionHeader, "1") + virtReq.Host = c.Req.Host + virtReq.Header = c.Req.Header.Clone() + virtReq.Trailer = c.Req.Trailer.Clone() + virtReq.Header.Set(recursionPreventionHeader, strconv.Itoa(recursionCount)) vrw := &virtualResponseWriter{body: buf, header: make(http.Header)} server := c.Req.Context().Value(caddyhttp.ServerCtxKey).(http.Handler) |