From a458544d9f6e6aaf72aeab0454acfa482880d3d6 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Wed, 16 Oct 2019 15:18:02 -0600 Subject: Minor enhancements/fixes to rewrite directive and template virt req's --- modules/caddyhttp/templates/tplcontext.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) (limited to 'modules/caddyhttp/templates') diff --git a/modules/caddyhttp/templates/tplcontext.go b/modules/caddyhttp/templates/tplcontext.go index 40d1370..e3909b2 100644 --- a/modules/caddyhttp/templates/tplcontext.go +++ b/modules/caddyhttp/templates/tplcontext.go @@ -22,6 +22,7 @@ import ( "net" "net/http" "path" + "strconv" "strings" "sync" @@ -79,8 +80,18 @@ func (c templateContext) Include(filename string, args ...interface{}) (template // are NOT escaped, so you should only include trusted resources. // If it is not trusted, be sure to use escaping functions yourself. func (c templateContext) HTTPInclude(uri string) (template.HTML, error) { - if c.Req.Header.Get(recursionPreventionHeader) == "1" { - return "", fmt.Errorf("virtual request cycle") + // prevent virtual request loops by counting how many levels + // deep we are; and if we get too deep, return an error + recursionCount := 1 + if numStr := c.Req.Header.Get(recursionPreventionHeader); numStr != "" { + num, err := strconv.Atoi(numStr) + if err != nil { + return "", fmt.Errorf("parsing %s: %v", recursionPreventionHeader, err) + } + if num >= 3 { + return "", fmt.Errorf("virtual request cycle") + } + recursionCount = num + 1 } buf := bufPool.Get().(*bytes.Buffer) @@ -91,7 +102,10 @@ func (c templateContext) HTTPInclude(uri string) (template.HTML, error) { if err != nil { return "", err } - virtReq.Header.Set(recursionPreventionHeader, "1") + virtReq.Host = c.Req.Host + virtReq.Header = c.Req.Header.Clone() + virtReq.Trailer = c.Req.Trailer.Clone() + virtReq.Header.Set(recursionPreventionHeader, strconv.Itoa(recursionCount)) vrw := &virtualResponseWriter{body: buf, header: make(http.Header)} server := c.Req.Context().Value(caddyhttp.ServerCtxKey).(http.Handler) -- cgit v1.2.3